Commit 32e4ca51 authored by qianyj's avatar qianyj
Browse files

Update code to v2.11.0

parents 9485aa1d 71060f67
This directory contains binaries and utils required for input preprocessing,
tokenization, etc that can be used with model building blocks available in
NLP modeling library [nlp/modelling](https://github.com/tensorflow/models/tree/master/official/nlp/modeling)
to train custom models and validate new research ideas.
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -24,7 +24,7 @@ from absl import logging ...@@ -24,7 +24,7 @@ from absl import logging
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
from official.nlp.bert import tokenization from official.nlp.tools import tokenization
class InputExample(object): class InputExample(object):
...@@ -187,6 +187,8 @@ class AxProcessor(DataProcessor): ...@@ -187,6 +187,8 @@ class AxProcessor(DataProcessor):
def _create_examples_tfds(self, dataset, set_type): def _create_examples_tfds(self, dataset, set_type):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -218,6 +220,8 @@ class ColaProcessor(DefaultGLUEDataProcessor): ...@@ -218,6 +220,8 @@ class ColaProcessor(DefaultGLUEDataProcessor):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
"glue/cola", split=set_type, try_gcs=True).as_numpy_iterator() "glue/cola", split=set_type, try_gcs=True).as_numpy_iterator()
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -312,6 +316,8 @@ class MnliProcessor(DataProcessor): ...@@ -312,6 +316,8 @@ class MnliProcessor(DataProcessor):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
"glue/mnli", split=set_type, try_gcs=True).as_numpy_iterator() "glue/mnli", split=set_type, try_gcs=True).as_numpy_iterator()
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -343,6 +349,8 @@ class MrpcProcessor(DefaultGLUEDataProcessor): ...@@ -343,6 +349,8 @@ class MrpcProcessor(DefaultGLUEDataProcessor):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
"glue/mrpc", split=set_type, try_gcs=True).as_numpy_iterator() "glue/mrpc", split=set_type, try_gcs=True).as_numpy_iterator()
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -453,6 +461,8 @@ class QnliProcessor(DefaultGLUEDataProcessor): ...@@ -453,6 +461,8 @@ class QnliProcessor(DefaultGLUEDataProcessor):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
"glue/qnli", split=set_type, try_gcs=True).as_numpy_iterator() "glue/qnli", split=set_type, try_gcs=True).as_numpy_iterator()
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -484,6 +494,8 @@ class QqpProcessor(DefaultGLUEDataProcessor): ...@@ -484,6 +494,8 @@ class QqpProcessor(DefaultGLUEDataProcessor):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
"glue/qqp", split=set_type, try_gcs=True).as_numpy_iterator() "glue/qqp", split=set_type, try_gcs=True).as_numpy_iterator()
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -517,6 +529,8 @@ class RteProcessor(DefaultGLUEDataProcessor): ...@@ -517,6 +529,8 @@ class RteProcessor(DefaultGLUEDataProcessor):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
"glue/rte", split=set_type, try_gcs=True).as_numpy_iterator() "glue/rte", split=set_type, try_gcs=True).as_numpy_iterator()
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -548,6 +562,8 @@ class SstProcessor(DefaultGLUEDataProcessor): ...@@ -548,6 +562,8 @@ class SstProcessor(DefaultGLUEDataProcessor):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
"glue/sst2", split=set_type, try_gcs=True).as_numpy_iterator() "glue/sst2", split=set_type, try_gcs=True).as_numpy_iterator()
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -574,6 +590,8 @@ class StsBProcessor(DefaultGLUEDataProcessor): ...@@ -574,6 +590,8 @@ class StsBProcessor(DefaultGLUEDataProcessor):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
"glue/stsb", split=set_type, try_gcs=True).as_numpy_iterator() "glue/stsb", split=set_type, try_gcs=True).as_numpy_iterator()
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
...@@ -742,6 +760,8 @@ class WnliProcessor(DefaultGLUEDataProcessor): ...@@ -742,6 +760,8 @@ class WnliProcessor(DefaultGLUEDataProcessor):
"""Creates examples for the training/dev/test sets.""" """Creates examples for the training/dev/test sets."""
dataset = tfds.load( dataset = tfds.load(
"glue/wnli", split=set_type, try_gcs=True).as_numpy_iterator() "glue/wnli", split=set_type, try_gcs=True).as_numpy_iterator()
dataset = list(dataset)
dataset.sort(key=lambda x: x["idx"])
examples = [] examples = []
for i, example in enumerate(dataset): for i, example in enumerate(dataset):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,8 +21,8 @@ from absl.testing import parameterized ...@@ -21,8 +21,8 @@ from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib from official.nlp.data import classifier_data_lib
from official.nlp.tools import tokenization
def decode_record(record, name_to_features): def decode_record(record, name_to_features):
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -22,7 +22,6 @@ import os ...@@ -22,7 +22,6 @@ import os
from absl import app from absl import app
from absl import flags from absl import flags
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib # word-piece tokenizer based squad_lib
...@@ -30,10 +29,10 @@ from official.nlp.data import squad_lib as squad_lib_wp ...@@ -30,10 +29,10 @@ from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib # sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib from official.nlp.data import tagging_data_lib
from official.nlp.tools import tokenization
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
# TODO(chendouble): consider moving each task to its own binary.
flags.DEFINE_enum( flags.DEFINE_enum(
"fine_tuning_task_type", "classification", "fine_tuning_task_type", "classification",
["classification", "regression", "squad", "retrieval", "tagging"], ["classification", "regression", "squad", "retrieval", "tagging"],
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -24,7 +24,7 @@ from absl import flags ...@@ -24,7 +24,7 @@ from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization from official.nlp.tools import tokenization
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""Create LM TF examples for XLNet.""" """Create LM TF examples for XLNet."""
import dataclasses
import json import json
import math import math
import os import os
...@@ -28,11 +29,10 @@ from absl import app ...@@ -28,11 +29,10 @@ from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import dataclasses
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.nlp.bert import tokenization from official.nlp.tools import tokenization
special_symbols = { special_symbols = {
"<unk>": 0, "<unk>": 0,
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -124,7 +124,7 @@ class DualEncoderDataLoader(data_loader.DataLoader): ...@@ -124,7 +124,7 @@ class DualEncoderDataLoader(data_loader.DataLoader):
raise ValueError('Expected {} to start with {}'.format(string, old)) raise ValueError('Expected {} to start with {}'.format(string, old))
def _switch_key_prefix(d, old, new): def _switch_key_prefix(d, old, new):
return {_switch_prefix(key, old, new): value for key, value in d.items()} return {_switch_prefix(key, old, new): value for key, value in d.items()} # pytype: disable=attribute-error # trace-all-classes
model_inputs = _switch_key_prefix( model_inputs = _switch_key_prefix(
self._bert_tokenize(record, self._left_text_fields), self._bert_tokenize(record, self._left_text_fields),
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -79,17 +79,29 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader): ...@@ -79,17 +79,29 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
def _decode(self, record: tf.Tensor): def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example.""" """Decodes a serialized tf.Example."""
name_to_features = { name_to_features = {
'input_ids': tf.io.VarLenFeature(tf.int64),
'input_mask': tf.io.VarLenFeature(tf.int64), 'input_mask': tf.io.VarLenFeature(tf.int64),
'segment_ids': tf.io.VarLenFeature(tf.int64),
'masked_lm_positions': tf.io.VarLenFeature(tf.int64), 'masked_lm_positions': tf.io.VarLenFeature(tf.int64),
'masked_lm_ids': tf.io.VarLenFeature(tf.int64), 'masked_lm_ids': tf.io.VarLenFeature(tf.int64),
'masked_lm_weights': tf.io.VarLenFeature(tf.float32), 'masked_lm_weights': tf.io.VarLenFeature(tf.float32),
} }
if self._params.use_v2_feature_names:
input_ids_key = 'input_word_ids'
segment_key = 'input_type_ids'
name_to_features.update({
input_ids_key: tf.io.VarLenFeature(tf.int64),
segment_key: tf.io.VarLenFeature(tf.int64),
})
else:
input_ids_key = 'input_ids'
segment_key = 'segment_ids'
name_to_features.update({
input_ids_key: tf.io.VarLenFeature(tf.int64),
segment_key: tf.io.VarLenFeature(tf.int64),
})
if self._use_next_sentence_label: if self._use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1], name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64) tf.int64)
dynamic_keys = ['input_ids', 'input_mask', 'segment_ids'] dynamic_keys = [input_ids_key, 'input_mask', segment_key]
if self._use_position_id: if self._use_position_id:
name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64) name_to_features['position_ids'] = tf.io.VarLenFeature(tf.int64)
dynamic_keys.append('position_ids') dynamic_keys.append('position_ids')
...@@ -102,7 +114,7 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader): ...@@ -102,7 +114,7 @@ class PretrainingDynamicDataLoader(pretrain_dataloader.BertPretrainDataLoader):
# sequence length dimension. # sequence length dimension.
# Pad before the first non pad from the back should not be removed. # Pad before the first non pad from the back should not be removed.
mask = tf.math.greater( mask = tf.math.greater(
tf.math.cumsum(example['input_ids'], reverse=True), 0) tf.math.cumsum(example[input_ids_key], reverse=True), 0)
for key in dynamic_keys: for key in dynamic_keys:
example[key] = tf.boolean_mask(example[key], mask) example[key] = tf.boolean_mask(example[key], mask)
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright 2021 The TensorFlow Authors. All Rights Reserved. # Copyright 2022 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment