Commit 47bc1813 authored by syiming

Merge remote-tracking branch 'upstream/master' into add_multilevel_crop_and_resize

parents d8611151 b035a227
@@ -26,7 +26,7 @@ class BertModelsTest(tf.test.TestCase):
def test_network_invocation(self):
config = bert.BertPretrainerConfig(
encoder=encoders.TransformerEncoderConfig(vocab_size=10, num_layers=1))
_ = bert.instantiate_bertpretrainer_from_cfg(config)
# Invokes with classification heads.
config = bert.BertPretrainerConfig(
@@ -35,7 +35,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_bertpretrainer_from_cfg(config)
with self.assertRaises(ValueError):
config = bert.BertPretrainerConfig(
@@ -47,7 +47,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
_ = bert.instantiate_bertpretrainer_from_cfg(config)
def test_checkpoint_items(self):
config = bert.BertPretrainerConfig(
@@ -56,7 +56,7 @@ class BertModelsTest(tf.test.TestCase):
bert.ClsHeadConfig(
inner_dim=10, num_classes=2, name="next_sentence")
])
encoder = bert.instantiate_bertpretrainer_from_cfg(config)
self.assertSameElements(encoder.checkpoint_items.keys(),
["encoder", "next_sentence.pooler_dense"])
@@ -13,11 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transformer Encoders.
Includes configurations and instantiation methods.
"""
import dataclasses
import tensorflow as tf
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.modeling import networks
@dataclasses.dataclass
@@ -28,9 +34,29 @@ class TransformerEncoderConfig(base_config.Config):
num_layers: int = 12
num_attention_heads: int = 12
hidden_activation: str = "gelu"
intermediate_size: int = 3072
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
def instantiate_encoder_from_cfg(
config: TransformerEncoderConfig) -> networks.TransformerEncoder:
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
encoder_network = networks.TransformerEncoder(
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
num_layers=config.num_layers,
num_attention_heads=config.num_attention_heads,
intermediate_size=config.intermediate_size,
activation=tf_utils.get_activation(config.hidden_activation),
dropout_rate=config.dropout_rate,
attention_dropout_rate=config.attention_dropout_rate,
sequence_length=None,
max_sequence_length=config.max_position_embeddings,
type_vocab_size=config.type_vocab_size,
initializer=tf.keras.initializers.TruncatedNormal(
stddev=config.initializer_range))
return encoder_network
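A quick usage sketch for the new helper; the values below are illustrative and not part of the commit:

# Hypothetical example: build a small encoder from its config.
small_config = TransformerEncoderConfig(vocab_size=30522, num_layers=2)
encoder = instantiate_encoder_from_cfg(small_config)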
@@ -33,7 +33,13 @@ from official.nlp.bert import tokenization
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self,
guid,
text_a,
text_b=None,
label=None,
weight=None,
int_iden=None):
"""Constructs an InputExample.
Args:
@@ -46,12 +52,15 @@ class InputExample(object):
specified for train and dev examples, but not for test examples.
weight: (Optional) float. The weight of the example to be used during
training.
int_iden: (Optional) int. The int identification number of example in the
corpus.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
self.weight = weight
self.int_iden = int_iden
class InputFeatures(object):
@@ -63,13 +72,15 @@ class InputFeatures(object):
segment_ids,
label_id,
is_real_example=True,
weight=None,
int_iden=None):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.is_real_example = is_real_example
self.weight = weight
self.int_iden = int_iden
class DataProcessor(object):
@@ -191,12 +202,68 @@ class XnliProcessor(DataProcessor):
return "XNLI"
class XtremeXnliProcessor(DataProcessor):
"""Processor for the XTREME XNLI data set."""
supported_languages = [
"ar", "bg", "de", "el", "en", "es", "fr", "hi", "ru", "sw", "th", "tr",
"ur", "vi", "zh"
]
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
guid = f"test-{i}"
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "contradiction"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
"""See base class."""
return ["contradiction", "entailment", "neutral"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XTREME-XNLI"
class PawsxProcessor(DataProcessor):
"""Processor for the PAWS-X data set."""
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
def __init__(self,
language="en",
process_text_fn=tokenization.convert_to_unicode):
@@ -219,8 +286,7 @@ class PawsxProcessor(DataProcessor):
train_tsv = "translated_train.tsv"
# Skips the header.
lines.extend(
self._read_tsv(os.path.join(data_dir, language, train_tsv))[1:])
examples = []
for (i, line) in enumerate(lines):
@@ -235,10 +301,9 @@ class PawsxProcessor(DataProcessor):
def get_dev_examples(self, data_dir):
"""See base class."""
lines = []
for lang in PawsxProcessor.supported_languages:
lines.extend(
self._read_tsv(os.path.join(data_dir, lang, "dev_2k.tsv"))[1:])
examples = []
for (i, line) in enumerate(lines):
@@ -252,17 +317,15 @@ class PawsxProcessor(DataProcessor):
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, lang, "test_2k.tsv"))[1:]
for (i, line) in enumerate(lines):
guid = "test-%d" % i
text_a = self.process_text_fn(line[1])
text_b = self.process_text_fn(line[2])
label = self.process_text_fn(line[3])
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
@@ -273,7 +336,62 @@ class PawsxProcessor(DataProcessor):
@staticmethod
def get_processor_name():
"""See base class."""
return "XTREME-PAWS-X"
class XtremePawsxProcessor(DataProcessor):
"""Processor for the XTREME PAWS-X data set."""
supported_languages = ["de", "en", "es", "fr", "ja", "ko", "zh"]
def get_train_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "train-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_dev_examples(self, data_dir):
"""See base class."""
lines = self._read_tsv(os.path.join(data_dir, "dev-en.tsv"))
examples = []
for (i, line) in enumerate(lines):
guid = "dev-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = self.process_text_fn(line[2])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def get_test_examples(self, data_dir):
"""See base class."""
examples_by_lang = {k: [] for k in self.supported_languages}
for lang in self.supported_languages:
lines = self._read_tsv(os.path.join(data_dir, f"test-{lang}.tsv"))
for (i, line) in enumerate(lines):
guid = "test-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
label = "0"
examples_by_lang[lang].append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples_by_lang
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "XTREME-PAWS-X"
class MnliProcessor(DataProcessor):
@@ -407,8 +525,8 @@ class QqpProcessor(DataProcessor):
label = line[5]
except IndexError:
continue
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
@@ -458,6 +576,53 @@ class ColaProcessor(DataProcessor):
return examples
class RteProcessor(DataProcessor):
"""Processor for the RTE data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
# All datasets are converted to 2-class split, where for 3-class datasets we
# collapse neutral and contradiction into not_entailment.
return ["entailment", "not_entailment"]
@staticmethod
def get_processor_name():
"""See base class."""
return "RTE"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
if set_type == "test":
label = "entailment"
else:
label = tokenization.convert_to_unicode(line[3])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
class SstProcessor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version)."""
@@ -583,15 +748,16 @@ class TfdsProcessor(DataProcessor):
is_regression: Whether the task is a regression problem (defaults to False).
"""
def __init__(self,
tfds_params,
process_text_fn=tokenization.convert_to_unicode):
super(TfdsProcessor, self).__init__(process_text_fn)
self._process_tfds_params_str(tfds_params)
if self.module_import:
importlib.import_module(self.module_import)
self.dataset, info = tfds.load(
self.dataset_name, data_dir=self.data_dir, with_info=True)
if self.is_regression:
self._labels = None
else:
@@ -660,11 +826,60 @@ class TfdsProcessor(DataProcessor):
if self.weight_key:
weight = float(example[self.weight_key])
examples.append(
InputExample(
guid=guid,
text_a=text_a,
text_b=text_b,
label=label,
weight=weight))
return examples
class WnliProcessor(DataProcessor):
"""Processor for the WNLI data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
return ["0", "1"]
@staticmethod
def get_processor_name():
"""See base class."""
return "WNLI"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for i, line in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[1])
text_b = tokenization.convert_to_unicode(line[2])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[3])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenizer):
"""Converts a single `InputExample` into a single `InputFeatures`."""
@@ -748,8 +963,9 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("label: %s (id = %s)", example.label, str(label_id))
logging.info("weight: %s", example.weight)
logging.info("int_iden: %s", str(example.int_iden))
feature = InputFeatures(
input_ids=input_ids,
@@ -757,13 +973,18 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
segment_ids=segment_ids,
label_id=label_id,
is_real_example=True,
weight=example.weight,
int_iden=example.int_iden)
return feature
def file_based_convert_examples_to_features(examples,
label_list,
max_seq_length,
tokenizer,
output_file,
label_type=None):
"""Convert a set of `InputExample`s to a TFRecord file."""
tf.io.gfile.makedirs(os.path.dirname(output_file))
@@ -779,6 +1000,7 @@ def file_based_convert_examples_to_features(examples, label_list,
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
def create_float_feature(values):
f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return f
@@ -789,12 +1011,14 @@ def file_based_convert_examples_to_features(examples, label_list,
features["segment_ids"] = create_int_feature(feature.segment_ids)
if label_type is not None and label_type == float:
features["label_ids"] = create_float_feature([feature.label_id])
elif feature.label_id is not None:
features["label_ids"] = create_int_feature([feature.label_id])
features["is_real_example"] = create_int_feature(
[int(feature.is_real_example)])
if feature.weight is not None:
features["weight"] = create_float_feature([feature.weight])
if feature.int_iden is not None:
features["int_iden"] = create_int_feature([feature.int_iden])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
writer.write(tf_example.SerializeToString())
@@ -830,8 +1054,7 @@ def generate_tf_record_from_data_file(processor,
Arguments:
processor: Input processor object to be used for generating data. Subclass
of `DataProcessor`.
data_dir: Directory that contains train/eval/test data to process.
tokenizer: The tokenizer to be applied on the data.
train_data_output_path: Output to which processed tf record for training
will be saved.
@@ -857,8 +1080,7 @@
train_input_data_examples = processor.get_train_examples(data_dir)
file_based_convert_examples_to_features(train_input_data_examples, label_list,
max_seq_length, tokenizer,
train_data_output_path, label_type)
num_training_data = len(train_input_data_examples)
if eval_data_output_path:
@@ -868,26 +1090,27 @@
tokenizer, eval_data_output_path,
label_type)
meta_data = {
"processor_type": processor.get_processor_name(),
"train_data_size": num_training_data,
"max_seq_length": max_seq_length,
}
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(data_dir)
if isinstance(test_input_data_examples, dict):
for language, examples in test_input_data_examples.items():
file_based_convert_examples_to_features(
examples, label_list, max_seq_length, tokenizer,
test_data_output_path.format(language), label_type)
meta_data["test_{}_data_size".format(language)] = len(examples)
else:
file_based_convert_examples_to_features(test_input_data_examples,
label_list, max_seq_length,
tokenizer, test_data_output_path,
label_type)
meta_data["test_data_size"] = len(test_input_data_examples)
if is_regression:
meta_data["task_type"] = "bert_regression"
meta_data["label_type"] = {int: "int", float: "float"}[label_type]
@@ -900,12 +1123,4 @@
if eval_data_output_path:
meta_data["eval_data_size"] = len(eval_input_data_examples)
return meta_data
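A hedged usage sketch of the updated pipeline; the paths are hypothetical and the keyword names follow the parameter names in the docstring above:

# Hypothetical example: write XTREME-XNLI tfrecords and collect the meta data.
tokenizer = tokenization.FullTokenizer(
    vocab_file="/tmp/vocab.txt", do_lower_case=True)
processor = XtremeXnliProcessor()
meta_data = generate_tf_record_from_data_file(
    processor,
    data_dir="/tmp/xtreme_xnli",
    tokenizer=tokenizer,
    train_data_output_path="/tmp/records/train.tfrecord",
    eval_data_output_path="/tmp/records/dev.tfrecord",
    # Per-language test files are written via .format(language), as above.
    test_data_output_path="/tmp/records/test_{}.tfrecord",
    max_seq_length=128)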
@@ -27,18 +27,21 @@ from absl import flags
import tensorflow as tf
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
from official.nlp.data import sentence_retrieval_lib
# word-piece tokenizer based squad_lib
from official.nlp.data import squad_lib as squad_lib_wp
# sentence-piece tokenizer based squad_lib
from official.nlp.data import squad_lib_sp
from official.nlp.data import tagging_data_lib
FLAGS = flags.FLAGS
# TODO(chendouble): consider moving each task to its own binary.
flags.DEFINE_enum(
"fine_tuning_task_type", "classification",
["classification", "regression", "squad", "retrieval", "tagging"],
"The name of the BERT fine tuning task for which data "
"will be generated.")
# BERT classification specific flags.
flags.DEFINE_string(
@@ -48,8 +51,12 @@ flags.DEFINE_string(
flags.DEFINE_enum("classification_task_name", "MNLI",
["COLA", "MNLI", "MRPC", "QNLI", "QQP", "SST-2", "XNLI",
"PAWS-X", "XTREME-XNLI", "XTREME-PAWS-X"],
"The name of the task to train BERT classifier. The "
"difference between XTREME-XNLI and XNLI is: 1. the format "
"of input tsv files; 2. the dev set for XTREME is english "
"only and for XNLI is all languages combined. Same for "
"PAWS-X.")
# XNLI task specific flag.
flags.DEFINE_string(
@@ -63,6 +70,14 @@ flags.DEFINE_string(
"Language of training data for PAWS-X task. If the value is 'all', the data "
"of all languages will be used for training.")
# Retrieval task specific flags.
flags.DEFINE_enum("retrieval_task_name", "bucc", ["bucc", "tatoeba"],
"The name of the sentence retrieval task for scoring.")
# Tagging task specific flags.
flags.DEFINE_enum("tagging_task_name", "panx", ["panx", "udpos"],
"The name of BERT tagging (token classification) task.")
# BERT Squad task specific flags.
flags.DEFINE_string(
"squad_data_file", None,
@@ -169,6 +184,7 @@ def generate_classifier_dataset():
"qnli":
classifier_data_lib.QnliProcessor,
"qqp": classifier_data_lib.QqpProcessor,
"rte": classifier_data_lib.RteProcessor,
"sst-2":
classifier_data_lib.SstProcessor,
"xnli":
@@ -176,7 +192,12 @@ def generate_classifier_dataset():
language=FLAGS.xnli_language),
"paws-x":
functools.partial(classifier_data_lib.PawsxProcessor,
language=FLAGS.pawsx_language),
"wnli": classifier_data_lib.WnliProcessor,
"xtreme-xnli":
functools.partial(classifier_data_lib.XtremeXnliProcessor),
"xtreme-paws-x":
functools.partial(classifier_data_lib.XtremePawsxProcessor)
}
task_name = FLAGS.classification_task_name.lower()
if task_name not in processors:
@@ -237,6 +258,67 @@ def generate_squad_dataset():
FLAGS.max_query_length, FLAGS.doc_stride, FLAGS.version_2_with_negative)
def generate_retrieval_dataset():
"""Generate retrieval test and dev dataset and returns input meta data."""
assert (FLAGS.input_data_dir and FLAGS.retrieval_task_name)
if FLAGS.tokenizer_impl == "word_piece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
else:
assert FLAGS.tokenizer_impl == "sentence_piece"
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
processors = {
"bucc": sentence_retrieval_lib.BuccProcessor,
"tatoeba": sentence_retrieval_lib.TatoebaProcessor,
}
task_name = FLAGS.retrieval_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
processor = processors[task_name](process_text_fn=processor_text_fn)
return sentence_retrieval_lib.generate_sentence_retrevial_tf_record(
processor,
FLAGS.input_data_dir,
tokenizer,
FLAGS.eval_data_output_path,
FLAGS.test_data_output_path,
FLAGS.max_seq_length)
def generate_tagging_dataset():
"""Generates tagging dataset."""
processors = {
"panx": tagging_data_lib.PanxProcessor,
"udpos": tagging_data_lib.UdposProcessor,
}
task_name = FLAGS.tagging_task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % task_name)
if FLAGS.tokenizer_impl == "word_piece":
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
processor_text_fn = tokenization.convert_to_unicode
elif FLAGS.tokenizer_impl == "sentence_piece":
tokenizer = tokenization.FullSentencePieceTokenizer(FLAGS.sp_model_file)
processor_text_fn = functools.partial(
tokenization.preprocess_text, lower=FLAGS.do_lower_case)
else:
raise ValueError("Unsupported tokenizer_impl: %s" % FLAGS.tokenizer_impl)
processor = processors[task_name]()
return tagging_data_lib.generate_tf_record_from_data_file(
processor, FLAGS.input_data_dir, tokenizer, FLAGS.max_seq_length,
FLAGS.train_data_output_path, FLAGS.eval_data_output_path,
FLAGS.test_data_output_path, processor_text_fn)
def main(_):
if FLAGS.tokenizer_impl == "word_piece":
if not FLAGS.vocab_file:
@@ -248,12 +330,20 @@ def main(_):
raise ValueError(
"FLAG sp_model_file for sentence-piece tokenizer is not specified.")
if FLAGS.fine_tuning_task_type != "retrieval":
flags.mark_flag_as_required("train_data_output_path")
if FLAGS.fine_tuning_task_type == "classification":
input_meta_data = generate_classifier_dataset()
elif FLAGS.fine_tuning_task_type == "regression":
input_meta_data = generate_regression_dataset()
elif FLAGS.fine_tuning_task_type == "retrieval":
input_meta_data = generate_retrieval_dataset()
elif FLAGS.fine_tuning_task_type == "squad":
input_meta_data = generate_squad_dataset()
else:
assert FLAGS.fine_tuning_task_type == "tagging"
input_meta_data = generate_tagging_dataset()
tf.io.gfile.makedirs(os.path.dirname(FLAGS.meta_data_file_path))
with tf.io.gfile.GFile(FLAGS.meta_data_file_path, "w") as writer:
@@ -261,6 +351,5 @@ def main(_):
if __name__ == "__main__":
flags.mark_flag_as_required("meta_data_file_path")
app.run(main)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BERT library to process data for cross lingual sentence retrieval task."""
import os
from absl import logging
from official.nlp.bert import tokenization
from official.nlp.data import classifier_data_lib
class BuccProcessor(classifier_data_lib.DataProcessor):
"""Procssor for Xtreme BUCC data set."""
supported_languages = ["de", "fr", "ru", "zh"]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode):
super(BuccProcessor, self).__init__(process_text_fn)
self.languages = BuccProcessor.supported_languages
def get_dev_examples(self, data_dir, file_pattern):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, file_pattern.format("dev"))),
"sample")
def get_test_examples(self, data_dir, file_pattern):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, file_pattern.format("test"))),
"test")
@staticmethod
def get_processor_name():
"""See base class."""
return "BUCC"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
int_iden = int(line[0].split("-")[1])
text_a = self.process_text_fn(line[1])
examples.append(
classifier_data_lib.InputExample(
guid=guid, text_a=text_a, int_iden=int_iden))
return examples
class TatoebaProcessor(classifier_data_lib.DataProcessor):
"""Procssor for Xtreme Tatoeba data set."""
supported_languages = [
"af", "ar", "bg", "bn", "de", "el", "es", "et", "eu", "fa", "fi", "fr",
"he", "hi", "hu", "id", "it", "ja", "jv", "ka", "kk", "ko", "ml", "mr",
"nl", "pt", "ru", "sw", "ta", "te", "th", "tl", "tr", "ur", "vi", "zh"
]
def __init__(self,
process_text_fn=tokenization.convert_to_unicode):
super(TatoebaProcessor, self).__init__(process_text_fn)
self.languages = TatoebaProcessor.supported_languages
def get_test_examples(self, data_dir, file_path):
return self._create_examples(
self._read_tsv(os.path.join(data_dir, file_path)), "test")
@staticmethod
def get_processor_name():
"""See base class."""
return "TATOEBA"
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = self.process_text_fn(line[0])
examples.append(
classifier_data_lib.InputExample(
guid=guid, text_a=text_a, int_iden=i))
return examples
def generate_sentence_retrevial_tf_record(processor,
data_dir,
tokenizer,
eval_data_output_path=None,
test_data_output_path=None,
max_seq_length=128):
"""Generates the tf records for retrieval tasks.
Args:
processor: Input processor object to be used for generating data. Subclass
of `DataProcessor`.
data_dir: Directory that contains the dev/test data to process, with file
names following the per-language patterns used below.
tokenizer: The tokenizer to be applied on the data.
eval_data_output_path: Output to which processed tf record for evaluation
will be saved.
test_data_output_path: Output to which processed tf record for testing
will be saved. Must be a pattern template with {} if processor has
language specific test data.
max_seq_length: Maximum sequence length of the to be generated
training/eval data.
Returns:
A dictionary containing input meta data.
"""
assert eval_data_output_path or test_data_output_path
if processor.get_processor_name() == "BUCC":
path_pattern = "{}-en.{{}}.{}"
if processor.get_processor_name() == "TATOEBA":
path_pattern = "{}-en.{}"
meta_data = {
"processor_type": processor.get_processor_name(),
"max_seq_length": max_seq_length,
"number_eval_data": {},
"number_test_data": {},
}
logging.info("Start to process %s task data", processor.get_processor_name())
for lang_a in processor.languages:
for lang_b in [lang_a, "en"]:
if eval_data_output_path:
eval_input_data_examples = processor.get_dev_examples(
data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
num_eval_data = len(eval_input_data_examples)
logging.info("Processing %d dev examples of %s-en.%s", num_eval_data,
lang_a, lang_b)
output_file = os.path.join(
eval_data_output_path,
"{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "dev"))
classifier_data_lib.file_based_convert_examples_to_features(
eval_input_data_examples, None, max_seq_length, tokenizer,
output_file, None)
meta_data["number_eval_data"][f"{lang_a}-en.{lang_b}"] = num_eval_data
if test_data_output_path:
test_input_data_examples = processor.get_test_examples(
data_dir, os.path.join(path_pattern.format(lang_a, lang_b)))
num_test_data = len(test_input_data_examples)
logging.info("Processing %d test examples of %s-en.%s", num_test_data,
lang_a, lang_b)
output_file = os.path.join(
test_data_output_path,
"{}-en-{}.{}.tfrecords".format(lang_a, lang_b, "test"))
classifier_data_lib.file_based_convert_examples_to_features(
test_input_data_examples, None, max_seq_length, tokenizer,
output_file, None)
meta_data["number_test_data"][f"{lang_a}-en.{lang_b}"] = num_test_data
return meta_data
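A hedged sketch of how the new retrieval helper might be driven for BUCC; the paths are hypothetical, and the tokenizer mirrors what create_finetuning_data.py builds:

# Hypothetical example: write BUCC dev/test tfrecords for each language pair.
tokenizer = tokenization.FullTokenizer(
    vocab_file="/tmp/vocab.txt", do_lower_case=True)
processor = BuccProcessor()
meta_data = generate_sentence_retrevial_tf_record(
    processor,
    data_dir="/tmp/bucc",
    tokenizer=tokenizer,
    eval_data_output_path="/tmp/bucc_records/dev",
    test_data_output_path="/tmp/bucc_records/test",
    max_seq_length=64)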
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library to process data for tagging task such as NER/POS."""
import collections
import os
from absl import logging
import tensorflow as tf
from official.nlp.data import classifier_data_lib
# A negative label id for the padding label, which will not contribute
# to loss/metrics in training.
_PADDING_LABEL_ID = -1
# The special unknown token, used to substitute a word which has too many
# subwords after tokenization.
_UNK_TOKEN = "[UNK]"
class InputExample(object):
"""A single training/test example for token classification."""
def __init__(self, sentence_id, words=None, label_ids=None):
"""Constructs an InputExample."""
self.sentence_id = sentence_id
self.words = words if words else []
self.label_ids = label_ids if label_ids else []
def add_word_and_label_id(self, word, label_id):
"""Adds word and label_id pair in the example."""
self.words.append(word)
self.label_ids.append(label_id)
def _read_one_file(file_name, label_list):
"""Reads one file and returns a list of `InputExample` instances."""
lines = tf.io.gfile.GFile(file_name, "r").readlines()
examples = []
label_id_map = {label: i for i, label in enumerate(label_list)}
sentence_id = 0
example = InputExample(sentence_id=0)
for line in lines:
line = line.strip("\n")
if line:
# The format is: <token>\t<label> for train/dev set and <token> for test.
items = line.split("\t")
assert len(items) == 2 or len(items) == 1
token = items[0].strip()
# Assign a dummy label_id for test set
label_id = label_id_map[items[1].strip()] if len(items) == 2 else 0
example.add_word_and_label_id(token, label_id)
else:
# Empty line indicates a new sentence.
if example.words:
examples.append(example)
sentence_id += 1
example = InputExample(sentence_id=sentence_id)
if example.words:
examples.append(example)
return examples
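The expected file layout is inferred from the parser above (the commit does not spell it out): one "token<TAB>label" pair per line for train/dev (just the token for test), with a blank line separating sentences. A tiny self-check:

# Hypothetical example: write a two-sentence toy file and parse it back.
with tf.io.gfile.GFile("/tmp/toy-train-en.tsv", "w") as f:
  f.write("John\tB-PER\nlives\tO\nhere\tO\n\nFine\tO\n")
toy_examples = _read_one_file(
    "/tmp/toy-train-en.tsv",
    ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"])
assert toy_examples[0].words == ["John", "lives", "here"]
assert toy_examples[1].label_ids == [0]  # "O" maps to index 0.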
class PanxProcessor(classifier_data_lib.DataProcessor):
"""Processor for the Panx data set."""
supported_languages = [
"ar", "he", "vi", "id", "jv", "ms", "tl", "eu", "ml", "ta", "te", "af",
"nl", "en", "de", "el", "bn", "hi", "mr", "ur", "fa", "fr", "it", "pt",
"es", "bg", "ru", "ja", "ka", "ko", "th", "sw", "yo", "my", "zh", "kk",
"tr", "et", "fi", "hu"
]
def get_train_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "train-en.tsv"), self.get_labels())
def get_dev_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
def get_test_examples(self, data_dir):
examples_dict = {}
for language in self.supported_languages:
examples_dict[language] = _read_one_file(
os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
return examples_dict
def get_labels(self):
return ["O", "B-PER", "I-PER", "B-LOC", "I-LOC", "B-ORG", "I-ORG"]
@staticmethod
def get_processor_name():
return "panx"
class UdposProcessor(classifier_data_lib.DataProcessor):
"""Processor for the Udpos data set."""
supported_languages = [
"af", "ar", "bg", "de", "el", "en", "es", "et", "eu", "fa", "fi", "fr",
"he", "hi", "hu", "id", "it", "ja", "kk", "ko", "mr", "nl", "pt", "ru",
"ta", "te", "th", "tl", "tr", "ur", "vi", "yo", "zh"
]
def get_train_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "train-en.tsv"), self.get_labels())
def get_dev_examples(self, data_dir):
return _read_one_file(
os.path.join(data_dir, "dev-en.tsv"), self.get_labels())
def get_test_examples(self, data_dir):
examples_dict = {}
for language in self.supported_languages:
examples_dict[language] = _read_one_file(
os.path.join(data_dir, "test-%s.tsv" % language), self.get_labels())
return examples_dict
def get_labels(self):
return [
"ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
"PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"
]
@staticmethod
def get_processor_name():
return "udpos"
def _tokenize_example(example, max_length, tokenizer, text_preprocessing=None):
"""Tokenizes words and breaks long example into short ones."""
# Needs additional [CLS] and [SEP] tokens.
max_length = max_length - 2
new_examples = []
new_example = InputExample(sentence_id=example.sentence_id)
for i, word in enumerate(example.words):
if any([x < 0 for x in example.label_ids]):
raise ValueError("Unexpected negative label_id: %s" % example.label_ids)
if text_preprocessing:
word = text_preprocessing(word)
subwords = tokenizer.tokenize(word)
if (not subwords or len(subwords) > max_length) and word:
subwords = [_UNK_TOKEN]
if len(subwords) + len(new_example.words) > max_length:
# Start a new example.
new_examples.append(new_example)
new_example = InputExample(sentence_id=example.sentence_id)
for j, subword in enumerate(subwords):
# Use the real label for the first subword, and pad label for
# the remaining subwords.
subword_label = example.label_ids[i] if j == 0 else _PADDING_LABEL_ID
new_example.add_word_and_label_id(subword, subword_label)
if new_example.words:
new_examples.append(new_example)
return new_examples
def _convert_single_example(example, max_seq_length, tokenizer):
"""Converts an `InputExample` instance to a `tf.train.Example` instance."""
tokens = ["[CLS]"]
tokens.extend(example.words)
tokens.append("[SEP]")
input_ids = tokenizer.convert_tokens_to_ids(tokens)
label_ids = [_PADDING_LABEL_ID]
label_ids.extend(example.label_ids)
label_ids.append(_PADDING_LABEL_ID)
segment_ids = [0] * len(input_ids)
input_mask = [1] * len(input_ids)
# Pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(_PADDING_LABEL_ID)
def create_int_feature(values):
return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
features = collections.OrderedDict()
features["input_ids"] = create_int_feature(input_ids)
features["input_mask"] = create_int_feature(input_mask)
features["segment_ids"] = create_int_feature(segment_ids)
features["label_ids"] = create_int_feature(label_ids)
features["sentence_id"] = create_int_feature([example.sentence_id])
tf_example = tf.train.Example(features=tf.train.Features(feature=features))
return tf_example
def write_example_to_file(examples,
tokenizer,
max_seq_length,
output_file,
text_preprocessing=None):
"""Writes `InputExample`s into a tfrecord file with `tf.train.Example` protos.
Note that each word in the example will be run through `text_preprocessing`
(if provided) and then tokenized. Also, if the length of a sentence (plus
special [CLS] and [SEP] tokens) exceeds `max_seq_length`, the long sentence
will be broken into multiple short examples. For example:
Example (text_preprocessing=lowercase, max_seq_length=5)
words: ["What", "a", "great", "weekend"]
labels: [ 7, 5, 9, 10]
sentence_id: 0
preprocessed: ["what", "a", "great", "weekend"]
tokenized: ["what", "a", "great", "week", "##end"]
will result in two tf.example protos:
tokens: ["[CLS]", "what", "a", "great", "[SEP]"]
label_ids: [-1, 7, 5, 9, -1]
input_mask: [ 1, 1, 1, 1, 1]
segment_ids: [ 0, 0, 0, 0, 0]
input_ids: [ tokenizer.convert_tokens_to_ids(tokens) ]
sentence_id: 0
tokens: ["[CLS]", "week", "##end", "[SEP]", "[PAD]"]
label_ids: [-1, 10, -1, -1, -1]
input_mask: [ 1, 1, 1, 0, 0]
segment_ids: [ 0, 0, 0, 0, 0]
input_ids: [ tokenizer.convert_tokens_to_ids(tokens) ]
sentence_id: 0
Note the use of -1 in `label_ids` to indicate that a token should not be
considered for classification (e.g., trailing ## wordpieces or special
token). Token classification models should accordingly ignore these when
calculating loss, metrics, etc...
Args:
examples: A list of `InputExample` instances.
tokenizer: The tokenizer to be applied on the data.
max_seq_length: Maximum length of generated sequences.
output_file: The name of the output tfrecord file.
text_preprocessing: optional preprocessing run on each word prior to
tokenization.
Returns:
The total number of tf.train.Example proto written to file.
"""
tf.io.gfile.makedirs(os.path.dirname(output_file))
writer = tf.io.TFRecordWriter(output_file)
num_tokenized_examples = 0
for (ex_index, example) in enumerate(examples):
if ex_index % 10000 == 0:
logging.info("Writing example %d of %d to %s", ex_index, len(examples),
output_file)
tokenized_examples = _tokenize_example(example, max_seq_length,
tokenizer, text_preprocessing)
num_tokenized_examples += len(tokenized_examples)
for per_tokenized_example in tokenized_examples:
tf_example = _convert_single_example(
per_tokenized_example, max_seq_length, tokenizer)
writer.write(tf_example.SerializeToString())
writer.close()
return num_tokenized_examples
def token_classification_meta_data(train_data_size,
max_seq_length,
num_labels,
eval_data_size=None,
test_data_size=None,
label_list=None,
processor_type=None):
"""Creates metadata for tagging (token classification) datasets."""
meta_data = {
"train_data_size": train_data_size,
"max_seq_length": max_seq_length,
"num_labels": num_labels,
"task_type": "tagging",
"label_type": "int",
"label_shape": [max_seq_length],
}
if eval_data_size:
meta_data["eval_data_size"] = eval_data_size
if test_data_size:
meta_data["test_data_size"] = test_data_size
if label_list:
meta_data["label_list"] = label_list
if processor_type:
meta_data["processor_type"] = processor_type
return meta_data
def generate_tf_record_from_data_file(processor,
data_dir,
tokenizer,
max_seq_length,
train_data_output_path,
eval_data_output_path,
test_data_output_path,
text_preprocessing):
"""Generates tfrecord files from the raw data."""
common_kwargs = dict(tokenizer=tokenizer, max_seq_length=max_seq_length,
text_preprocessing=text_preprocessing)
train_examples = processor.get_train_examples(data_dir)
train_data_size = write_example_to_file(
train_examples, output_file=train_data_output_path, **common_kwargs)
eval_examples = processor.get_dev_examples(data_dir)
eval_data_size = write_example_to_file(
eval_examples, output_file=eval_data_output_path, **common_kwargs)
test_input_data_examples = processor.get_test_examples(data_dir)
test_data_size = {}
for language, examples in test_input_data_examples.items():
test_data_size[language] = write_example_to_file(
examples,
output_file=test_data_output_path.format(language),
**common_kwargs)
labels = processor.get_labels()
meta_data = token_classification_meta_data(
train_data_size,
max_seq_length,
len(labels),
eval_data_size,
test_data_size,
label_list=labels,
processor_type=processor.get_processor_name())
return meta_data
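A hedged end-to-end sketch of the tagging pipeline above; the paths are hypothetical, and PanxProcessor is assumed to find the train-en.tsv / dev-en.tsv / test-<lang>.tsv layout it reads:

# Hypothetical example: build PAN-X tfrecords and the accompanying meta data.
from official.nlp.bert import tokenization  # assumed available, as in create_finetuning_data.py

tokenizer = tokenization.FullTokenizer(
    vocab_file="/tmp/vocab.txt", do_lower_case=True)
meta_data = generate_tf_record_from_data_file(
    PanxProcessor(),
    data_dir="/tmp/panx",
    tokenizer=tokenizer,
    max_seq_length=128,
    train_data_output_path="/tmp/panx_records/train.tfrecord",
    eval_data_output_path="/tmp/panx_records/dev.tfrecord",
    test_data_output_path="/tmp/panx_records/test_{}.tfrecord",
    text_preprocessing=tokenization.convert_to_unicode)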
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loads dataset for the tagging (e.g., NER/POS) task."""
from typing import Mapping, Optional
import tensorflow as tf
from official.core import input_reader
class TaggingDataLoader:
"""A class to load dataset for tagging (e.g., NER and POS) task."""
def __init__(self, params):
self._params = params
self._seq_length = params.seq_length
def _decode(self, record: tf.Tensor):
"""Decodes a serialized tf.Example."""
name_to_features = {
'input_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'input_mask': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'segment_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
'label_ids': tf.io.FixedLenFeature([self._seq_length], tf.int64),
}
example = tf.io.parse_single_example(record, name_to_features)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for name in example:
t = example[name]
if t.dtype == tf.int64:
t = tf.cast(t, tf.int32)
example[name] = t
return example
def _parse(self, record: Mapping[str, tf.Tensor]):
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x = {
'input_word_ids': record['input_ids'],
'input_mask': record['input_mask'],
'input_type_ids': record['segment_ids']
}
y = record['label_ids']
return (x, y)
def load(self, input_context: Optional[tf.distribute.InputContext] = None):
"""Returns a tf.dataset.Dataset."""
reader = input_reader.InputReader(
params=self._params, decoder_fn=self._decode, parser_fn=self._parse)
return reader.read(input_context)
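A brief usage sketch. The concrete `params` config class is not part of this commit; the constructor only reads `seq_length`, so a stand-in is used here purely for illustration:

import collections

# Hypothetical stand-in for the real data config; only seq_length is read by
# __init__ and _decode above.
FakeParams = collections.namedtuple("FakeParams", ["seq_length"])
loader = TaggingDataLoader(FakeParams(seq_length=128))
# loader.load() additionally needs whatever fields input_reader.InputReader
# expects (input path, global batch size, ...), which are defined outside this
# commit.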
@@ -9,13 +9,17 @@ assemble new layers, networks, or models.
initialization parameters.
* [MultiHeadAttention](attention.py) implements an optionally masked attention
between query, key, value tensors as described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). If
`from_tensor` and `to_tensor` are the same, then this is self-attention.
* [CachedAttention](attention.py) implements an attention layer with cache
used for auto-regressive decoding.
* [MultiChannelAttention](multi_channel_attention.py) implements a variant of
multi-head attention which can be used to merge multiple streams for
cross-attentions.
* [TalkingHeadsAttention](talking_heads_attention.py) implements the talking
heads attention, as described in
["Talking-Heads Attention"](https://arxiv.org/abs/2003.02436).
@@ -24,6 +28,10 @@ assemble new layers, networks, or models.
described in
["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).
* [TransformerDecoderLayer](transformer.py) TransformerDecoderLayer is made up
of self multi-head attention, cross multi-head attention and a
feedforward network.
* [ReZeroTransformer](rezero_transformer.py) implements Transformer with
ReZero described in
["ReZero is All You Need: Fast Convergence at Large Depth"](https://arxiv.org/abs/2003.04887).
@@ -45,6 +53,9 @@ assemble new layers, networks, or models.
should be masked), the output will have masked positions set to
approximately zero.
* [`MaskedLM`](masked_lm.py) implements a masked language model. It assumes
the embedding table variable is passed to it.
* [ClassificationHead](cls_head.py) A pooling head over a sequence of
embeddings, commonly used by classification tasks.
...@@ -18,11 +18,13 @@ from official.nlp.modeling.layers.attention import * ...@@ -18,11 +18,13 @@ from official.nlp.modeling.layers.attention import *
from official.nlp.modeling.layers.cls_head import * from official.nlp.modeling.layers.cls_head import *
from official.nlp.modeling.layers.dense_einsum import DenseEinsum from official.nlp.modeling.layers.dense_einsum import DenseEinsum
from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward from official.nlp.modeling.layers.gated_feedforward import GatedFeedforward
from official.nlp.modeling.layers.masked_lm import MaskedLM
from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax from official.nlp.modeling.layers.masked_softmax import MaskedSoftmax
from official.nlp.modeling.layers.multi_channel_attention import *
from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding from official.nlp.modeling.layers.on_device_embedding import OnDeviceEmbedding
from official.nlp.modeling.layers.position_embedding import PositionEmbedding from official.nlp.modeling.layers.position_embedding import PositionEmbedding
from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer from official.nlp.modeling.layers.rezero_transformer import ReZeroTransformer
from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask from official.nlp.modeling.layers.self_attention_mask import SelfAttentionMask
from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention from official.nlp.modeling.layers.talking_heads_attention import TalkingHeadsAttention
from official.nlp.modeling.layers.transformer import Transformer from official.nlp.modeling.layers.transformer import *
from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold from official.nlp.modeling.layers.transformer_scaffold import TransformerScaffold
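With the import lines added above, the new layers become reachable from the layers package namespace. A small sketch of that import surface (assuming `multi_channel_attention.py` defines no restrictive `__all__`):

```python
from official.nlp.modeling import layers

masked_lm_cls = layers.MaskedLM               # explicit import added above
multi_channel_cls = layers.MultiChannelAttention  # via the new wildcard import
voting_cls = layers.VotingAttention              # same module, same wildcard
# transformer.py is now re-exported through a wildcard import as well.
```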
...@@ -25,91 +25,74 @@ from official.modeling import tf_utils ...@@ -25,91 +25,74 @@ from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package='Text') @tf.keras.utils.register_keras_serializable(package='Text')
class MaskedLM(tf.keras.Model): class MaskedLM(tf.keras.layers.Layer):
"""Masked language model network head for BERT modeling. """Masked language model network head for BERT modeling.
This network implements a masked language model based on the provided network. This network implements a masked language model based on the provided network.
It assumes that the network being passed has a "get_embedding_table()" method. It assumes that the network being passed has a "get_embedding_table()" method.
Arguments: Arguments:
input_width: The innermost dimension of the input tensor to this network. embedding_table: The embedding table of the targets.
num_predictions: The number of predictions to make per sequence. activation: The activation, if any, for the dense layer.
source_network: The network with the embedding layer to use for the initializer: The initializer for the dense layer. Defaults to a Glorot
embedding layer. uniform initializer.
embedding_table: The embedding table of a source network. If None, the
`source_network.get_embedding_table()` method is used.
activation: The activation, if any, for the dense layer in this network.
initializer: The initializer for the dense layer in this network. Defaults to
a Glorot uniform initializer.
output: The output style for this network. Can be either 'logits' or output: The output style for this network. Can be either 'logits' or
'predictions'. 'predictions'.
""" """
def __init__(self, def __init__(self,
input_width, embedding_table,
num_predictions,
source_network,
embedding_table=None,
activation=None, activation=None,
initializer='glorot_uniform', initializer='glorot_uniform',
output='logits', output='logits',
name='cls/predictions',
**kwargs): **kwargs):
super(MaskedLM, self).__init__(name=name, **kwargs)
self.embedding_table = embedding_table
self.activation = activation
self.initializer = tf.keras.initializers.get(initializer)
if embedding_table is None: if output not in ('predictions', 'logits'):
embedding_table = source_network.get_embedding_table()
vocab_size, hidden_size = embedding_table.shape
sequence_data = tf.keras.layers.Input(
shape=(None, input_width), name='sequence_data', dtype=tf.float32)
masked_lm_positions = tf.keras.layers.Input(
shape=(num_predictions,), name='masked_lm_positions', dtype=tf.int32)
masked_lm_input = tf.keras.layers.Lambda(
lambda x: self._gather_indexes(x[0], x[1]))(
[sequence_data, masked_lm_positions])
lm_data = (
tf.keras.layers.Dense(
hidden_size,
activation=activation,
kernel_initializer=initializer,
name='cls/predictions/transform/dense')(masked_lm_input))
lm_data = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='cls/predictions/transform/LayerNorm')(
lm_data)
lm_data = tf.keras.layers.Lambda(
lambda x: tf.matmul(x, embedding_table, transpose_b=True))(
lm_data)
logits = Bias(
initializer=tf.keras.initializers.Zeros(),
name='cls/predictions/output_bias')(
lm_data)
# We can't use the standard Keras reshape layer here, since it expects
# the input and output batch size to be the same.
reshape_layer = tf.keras.layers.Lambda(
lambda x: tf.reshape(x, [-1, num_predictions, vocab_size]))
self.logits = reshape_layer(logits)
predictions = tf.keras.layers.Activation(tf.nn.log_softmax)(self.logits)
if output == 'logits':
output_tensors = self.logits
elif output == 'predictions':
output_tensors = predictions
else:
raise ValueError( raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or ' ('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output) '"predictions"') % output)
self._output_type = output
def build(self, input_shape):
self._vocab_size, hidden_size = self.embedding_table.shape
self.dense = tf.keras.layers.Dense(
hidden_size,
activation=self.activation,
kernel_initializer=self.initializer,
name='transform/dense')
self.layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='transform/LayerNorm')
self.bias = self.add_weight(
'output_bias/bias',
shape=(self._vocab_size,),
initializer='zeros',
trainable=True)
super(MaskedLM, self).build(input_shape)
super(MaskedLM, self).__init__( def call(self, sequence_data, masked_positions):
inputs=[sequence_data, masked_lm_positions], masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
outputs=output_tensors, lm_data = self.dense(masked_lm_input)
**kwargs) lm_data = self.layer_norm(lm_data)
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
logits = tf.nn.bias_add(lm_data, self.bias)
masked_positions_shape = tf_utils.get_shape_list(
masked_positions, name='masked_positions_tensor')
logits = tf.reshape(logits,
[-1, masked_positions_shape[1], self._vocab_size])
if self._output_type == 'logits':
return logits
return tf.nn.log_softmax(logits)
def get_config(self): def get_config(self):
raise NotImplementedError('MaskedLM cannot be directly serialized at this ' raise NotImplementedError('MaskedLM cannot be directly serialized because '
'time. Please use it only in Layers or ' 'it has variable sharing logic.')
'functionally subclassed Models/Networks.')
def _gather_indexes(self, sequence_tensor, positions): def _gather_indexes(self, sequence_tensor, positions):
"""Gathers the vectors at the specific positions. """Gathers the vectors at the specific positions.
...@@ -139,51 +122,3 @@ class MaskedLM(tf.keras.Model): ...@@ -139,51 +122,3 @@ class MaskedLM(tf.keras.Model):
output_tensor = tf.gather(flat_sequence_tensor, flat_positions) output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor return output_tensor
@tf.keras.utils.register_keras_serializable(package='Text')
# Temporary until we can create a Dense layer that ties the embedding.
class Bias(tf.keras.layers.Layer):
"""Adds a bias term to an input."""
def __init__(self,
initializer='zeros',
regularizer=None,
constraint=None,
activation=None,
**kwargs):
super(Bias, self).__init__(**kwargs)
self._initializer = tf.keras.initializers.get(initializer)
self._regularizer = tf.keras.regularizers.get(regularizer)
self._constraint = tf.keras.constraints.get(constraint)
self._activation = tf.keras.activations.get(activation)
def build(self, input_shape):
input_shape = tf.TensorShape(input_shape)
self._bias = self.add_weight(
'bias',
shape=input_shape[1:],
initializer=self._initializer,
regularizer=self._regularizer,
constraint=self._constraint,
dtype=self._dtype,
trainable=True)
super(Bias, self).build(input_shape)
def get_config(self):
config = {
'activation': tf.keras.activations.serialize(self._activation),
'initializer': tf.keras.initializers.serialize(self._initializer),
'regularizer': tf.keras.regularizers.serialize(self._regularizer),
'constraint': tf.keras.constraints.serialize(self._constraint)
}
base_config = super(Bias, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs):
outputs = tf.nn.bias_add(inputs, self._bias)
if self._activation is not None:
return self._activation(outputs) # pylint: disable=not-callable
else:
return outputs
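A usage sketch for the refactored MaskedLM layer, following the constructor and call signature introduced above; the embedding table, shapes, and values are dummies:

```python
import tensorflow as tf
from official.nlp.modeling import layers

vocab_size, hidden_size = 100, 16
embedding_table = tf.Variable(
    tf.random.normal([vocab_size, hidden_size]), name='word_embeddings')

masked_lm = layers.MaskedLM(
    embedding_table=embedding_table,
    activation=None,        # activation of the transform dense layer
    output='logits')        # or 'predictions' for log-probabilities

batch, seq_len, num_predictions = 2, 8, 3
sequence_data = tf.random.normal([batch, seq_len, hidden_size])
masked_positions = tf.constant([[0, 2, 5], [1, 4, 7]], dtype=tf.int32)

logits = masked_lm(sequence_data, masked_positions)
print(logits.shape)  # (2, 3, 100): one vocabulary distribution per masked slot
```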
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Multi-channel decoder.""" """Multi-channel Attention."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -24,11 +25,25 @@ import math ...@@ -24,11 +25,25 @@ import math
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.nlp.modeling import layers from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax
class DocAttention(tf.keras.layers.Layer):
"""Documents Attention layer."""
class VotingAttention(tf.keras.layers.Layer):
"""Voting Attention layer.
Arguments:
num_heads: the number of attention heads.
head_size: per-head hidden size.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
"""
def __init__(self, def __init__(self,
num_heads, num_heads,
...@@ -41,7 +56,7 @@ class DocAttention(tf.keras.layers.Layer): ...@@ -41,7 +56,7 @@ class DocAttention(tf.keras.layers.Layer):
kernel_constraint=None, kernel_constraint=None,
bias_constraint=None, bias_constraint=None,
**kwargs): **kwargs):
super(DocAttention, self).__init__(**kwargs) super(VotingAttention, self).__init__(**kwargs)
self._num_heads = num_heads self._num_heads = num_heads
self._head_size = head_size self._head_size = head_size
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer) self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
...@@ -52,7 +67,7 @@ class DocAttention(tf.keras.layers.Layer): ...@@ -52,7 +67,7 @@ class DocAttention(tf.keras.layers.Layer):
self._bias_constraint = tf.keras.constraints.get(bias_constraint) self._bias_constraint = tf.keras.constraints.get(bias_constraint)
def build(self, unused_input_shapes): def build(self, unused_input_shapes):
self._query_dense = layers.DenseEinsum( self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size), output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
...@@ -63,7 +78,7 @@ class DocAttention(tf.keras.layers.Layer): ...@@ -63,7 +78,7 @@ class DocAttention(tf.keras.layers.Layer):
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint,
dtype=self.dtype, dtype=self.dtype,
name="encdocatt_query") name="encdocatt_query")
self._key_dense = layers.DenseEinsum( self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._head_size), output_shape=(self._num_heads, self._head_size),
kernel_initializer=self._kernel_initializer, kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer, bias_initializer=self._bias_initializer,
...@@ -74,7 +89,7 @@ class DocAttention(tf.keras.layers.Layer): ...@@ -74,7 +89,7 @@ class DocAttention(tf.keras.layers.Layer):
bias_constraint=self._bias_constraint, bias_constraint=self._bias_constraint,
dtype=self.dtype, dtype=self.dtype,
name="encdocatt_key") name="encdocatt_key")
super(DocAttention, self).build(unused_input_shapes) super(VotingAttention, self).build(unused_input_shapes)
def call(self, encoder_outputs, doc_attention_mask): def call(self, encoder_outputs, doc_attention_mask):
num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1] num_docs = tf_utils.get_shape_list(encoder_outputs, expected_rank=[4])[1]
...@@ -95,12 +110,16 @@ class DocAttention(tf.keras.layers.Layer): ...@@ -95,12 +110,16 @@ class DocAttention(tf.keras.layers.Layer):
return tf.nn.softmax(doc_attention_probs + infadder) return tf.nn.softmax(doc_attention_probs + infadder)
class MultiChannelAttention(layers.MultiHeadAttention): class MultiChannelAttention(attention.MultiHeadAttention):
"""Multi-channel Attention layer.""" """Multi-channel Attention layer.
Introduced in: https://arxiv.org/abs/2001.09386. Expects multiple
cross-attention target sequences.
"""
def build(self, input_shape): def _build_attention(self, qkv_rank):
super(MultiChannelAttention, self).build(input_shape) super(MultiChannelAttention, self)._build_attention(qkv_rank)
self._masked_softmax = layers.MaskedSoftmax(mask_expansion_axes=[2]) self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[2])
def call(self, inputs, attention_mask=None): def call(self, inputs, attention_mask=None):
from_tensor = inputs[0] from_tensor = inputs[0]
......
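A sketch of calling the renamed VotingAttention layer, using the same input shapes as the updated test that follows: a rank-4 batch of per-document encoder outputs plus a per-document mask. Values are dummies; the output is the layer's attention weighting over documents.

```python
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import multi_channel_attention

doc_attention = multi_channel_attention.VotingAttention(
    num_heads=2, head_size=8)

batch, num_docs, seq_len, hidden = 2, 3, 10, 16
encoder_outputs = np.zeros((batch, num_docs, seq_len, hidden), dtype=np.float32)
doc_mask = np.ones((batch, num_docs), dtype=np.float32)  # 1 = real document

doc_probs = doc_attention(encoder_outputs, doc_mask)
print(doc_probs.shape)
```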
...@@ -22,14 +22,15 @@ from __future__ import print_function ...@@ -22,14 +22,15 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.nlp.nhnet import multi_channel_attention from official.nlp.modeling.layers import multi_channel_attention
class MultiChannelAttentionTest(tf.test.TestCase): class MultiChannelAttentionTest(tf.test.TestCase):
def test_doc_attention(self): def test_doc_attention(self):
num_heads = 2 num_heads = 2
doc_attention = multi_channel_attention.DocAttention(num_heads, head_size=8) doc_attention = multi_channel_attention.VotingAttention(
num_heads, head_size=8)
num_docs = 3 num_docs = 3
inputs = np.zeros((2, num_docs, 10, 16), dtype=np.float32) inputs = np.zeros((2, num_docs, 10, 16), dtype=np.float32)
doc_mask = np.zeros((2, num_docs), dtype=np.float32) doc_mask = np.zeros((2, num_docs), dtype=np.float32)
......
...@@ -215,5 +215,39 @@ class TransformerLayerTest(keras_parameterized.TestCase): ...@@ -215,5 +215,39 @@ class TransformerLayerTest(keras_parameterized.TestCase):
self.assertAllEqual([1, input_length, width], output_data.shape) self.assertAllEqual([1, input_length, width], output_data.shape)
def _create_cache(batch_size, init_decode_length, num_heads, head_size):
return {
'key':
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32),
'value':
tf.zeros([batch_size, init_decode_length, num_heads, head_size],
dtype=tf.float32)
}
@keras_parameterized.run_all_keras_modes
class TransformerDecoderLayerTest(keras_parameterized.TestCase):
def test_decoder_block_with_cache(self):
num_attention_heads = 2
hidden_size = 16
decoder_block = transformer.TransformerDecoderLayer(
num_attention_heads=num_attention_heads,
intermediate_size=32,
intermediate_activation='relu',
dropout_rate=0.1,
attention_dropout_rate=0.1)
# Forward path.
dummy_tensor = tf.zeros([2, 4, 16], dtype=tf.float32)
dummy_mask = tf.zeros([2, 4, 4], dtype=tf.float32)
inputs = [dummy_tensor, dummy_tensor, dummy_mask, dummy_mask]
cache = _create_cache(2, 0, num_attention_heads,
hidden_size // num_attention_heads)
output, cache = decoder_block(inputs, cache)
self.assertEqual(output.shape, (2, 4, hidden_size))
self.assertEqual(cache['value'].shape, (2, 4, 2, 8))
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
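A short note on the cache bookkeeping the new decoder test exercises. The interpretation below is inferred from the test itself and is simple shape arithmetic only:

```python
# Why the test expects cache['value'].shape == (2, 4, 2, 8): the cache starts
# with init_decode_length=0 entries, the 4 target positions of this call are
# appended along the length axis, and each position stores num_heads=2 vectors
# of size hidden_size // num_heads = 8.
batch_size, init_decode_length, seq_length = 2, 0, 4
num_heads, hidden_size = 2, 16
expected_cache_shape = (batch_size,
                        init_decode_length + seq_length,
                        num_heads,
                        hidden_size // num_heads)
assert expected_cache_shape == (2, 4, 2, 8)
```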
...@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks. ...@@ -4,6 +4,3 @@ Losses contains common loss computation used in NLP tasks.
* `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse * `weighted_sparse_categorical_crossentropy_loss` computes per-batch sparse
categorical crossentropy loss. categorical crossentropy loss.
* `weighted_sparse_categorical_crossentropy_per_example_loss` computes
per-example sparse categorical crossentropy loss.
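For reference, the quantity a weighted, per-batch sparse categorical crossentropy reduces to, computed here with plain TensorFlow ops rather than the library function (whose exact argument names are not shown in this diff); shapes and weights are illustrative:

```python
import tensorflow as tf

logits = tf.random.normal([2, 3, 100])        # [batch, num_predictions, vocab]
labels = tf.constant([[5, 7, 0], [2, 9, 0]])  # target token ids
weights = tf.constant([[1., 1., 0.],          # 0 marks padded prediction slots
                       [1., 1., 0.]])

per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=logits)             # [batch, num_predictions]
per_batch = (tf.reduce_sum(per_example * weights) /
             (tf.reduce_sum(weights) + 1e-5)) # single scalar for the batch
print(per_batch.numpy())
```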
...@@ -14,4 +14,3 @@ ...@@ -14,4 +14,3 @@
# ============================================================================== # ==============================================================================
"""Activations package definition. Subject to change.""" """Activations package definition. Subject to change."""
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import loss as weighted_sparse_categorical_crossentropy_loss
from official.nlp.modeling.losses.weighted_sparse_categorical_crossentropy import per_example_loss as weighted_sparse_categorical_crossentropy_per_example_loss