Commit 444e8c79 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Update XLNet classifier:

- Support first/last summary type
- Support bert format input processing

PiperOrigin-RevId: 276005310
parent 9616954e
......@@ -13,18 +13,12 @@
# limitations under the License.
# ==============================================================================
"""Utilities for pre-processing classification data."""
from absl import flags
from absl import logging
from official.nlp.xlnet import data_utils
FLAGS = flags.FLAGS
SEG_ID_A = 0
SEG_ID_B = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4
class PaddingInputExample(object):
......@@ -72,8 +66,8 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
tokens_b.pop()
def convert_single_example(ex_index, example, label_list, max_seq_length,
tokenize_fn):
def convert_single_example(example_index, example, label_list, max_seq_length,
tokenize_fn, use_bert_format):
"""Converts a single `InputExample` into a single `InputFeatures`."""
if isinstance(example, PaddingInputExample):
......@@ -119,8 +113,12 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_B)
tokens.append(data_utils.CLS_ID)
segment_ids.append(SEG_ID_CLS)
if use_bert_format:
tokens.insert(0, data_utils.CLS_ID)
segment_ids.insert(0, data_utils.SEG_ID_CLS)
else:
tokens.append(data_utils.CLS_ID)
segment_ids.append(data_utils.SEG_ID_CLS)
input_ids = tokens
......@@ -131,9 +129,14 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
# Zero-pad up to the sequence length.
if len(input_ids) < max_seq_length:
delta_len = max_seq_length - len(input_ids)
input_ids = [0] * delta_len + input_ids
input_mask = [1] * delta_len + input_mask
segment_ids = [SEG_ID_PAD] * delta_len + segment_ids
if use_bert_format:
input_ids = input_ids + [0] * delta_len
input_mask = input_mask + [1] * delta_len
segment_ids = segment_ids + [data_utils.SEG_ID_PAD] * delta_len
else:
input_ids = [0] * delta_len + input_ids
input_mask = [1] * delta_len + input_mask
segment_ids = [data_utils.SEG_ID_PAD] * delta_len + segment_ids
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
......@@ -143,7 +146,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
label_id = label_map[example.label]
else:
label_id = example.label
if ex_index < 5:
if example_index < 5:
logging.info("*** Example ***")
logging.info("guid: %s", (example.guid))
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
......
......@@ -55,6 +55,10 @@ flags.DEFINE_integer(
flags.DEFINE_bool("uncased", default=False, help="Use uncased.")
flags.DEFINE_bool(
"is_regression", default=False, help="Whether it's a regression task.")
flags.DEFINE_bool(
"use_bert_format",
default=False,
help="Whether to use BERT format to arrange input data.")
FLAGS = flags.FLAGS
......@@ -356,7 +360,8 @@ def file_based_convert_examples_to_features(examples,
feature = classifier_utils.convert_single_example(ex_index, example,
label_list,
max_seq_length,
tokenize_fn)
tokenize_fn,
FLAGS.use_bert_format)
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
......
......@@ -36,13 +36,20 @@ from official.nlp.xlnet import training_utils
from official.utils.misc import tpu_lib
flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
flags.DEFINE_string(
"summary_type",
default="last",
help="Method used to summarize a sequence into a vector.")
FLAGS = flags.FLAGS
def get_classificationxlnet_model(model_config, run_config, n_class):
def get_classificationxlnet_model(model_config,
run_config,
n_class,
summary_type="last"):
model = modeling.ClassificationXLNetModel(
model_config, run_config, n_class, name="model")
model_config, run_config, n_class, summary_type, name="model")
return model
......@@ -65,6 +72,7 @@ def run_evaluation(strategy,
them when calculating the accuracy. For the reason that there will be
dynamic-shape tensor, we first collect logits, labels and masks from TPU
and calculate the accuracy via numpy locally.
Returns:
A float metric, accuracy.
"""
......@@ -159,7 +167,7 @@ def main(unused_argv):
model_config = xlnet_config.XLNetConfig(FLAGS)
run_config = xlnet_config.create_run_config(True, False, FLAGS)
model_fn = functools.partial(get_classificationxlnet_model, model_config,
run_config, FLAGS.n_class)
run_config, FLAGS.n_class, FLAGS.summary_type)
input_meta_data = {}
input_meta_data["d_model"] = FLAGS.d_model
input_meta_data["mem_len"] = FLAGS.mem_len
......
......@@ -888,7 +888,7 @@ class ClassificationXLNetModel(tf.keras.Model):
"""
def __init__(self, xlnet_config, run_config, n_class, **kwargs):
def __init__(self, xlnet_config, run_config, n_class, summary_type, **kwargs):
super(ClassificationXLNetModel, self).__init__(**kwargs)
self.run_config = run_config
self.initializer = _get_initializer(run_config)
......@@ -924,7 +924,7 @@ class ClassificationXLNetModel(tf.keras.Model):
dropout_att=self.run_config.dropout_att,
initializer=self.initializer,
use_proj=True,
summary_type='last',
summary_type=summary_type,
name='sequence_summary')
self.cl_loss_layer = ClassificationLossLayer(
......@@ -946,9 +946,9 @@ class ClassificationXLNetModel(tf.keras.Model):
self.transformerxl_model(
inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask, mems=mems))
self.summary = self.summarization_layer(transformerxl_output)
summary = self.summarization_layer(transformerxl_output)
per_example_loss, logits = self.cl_loss_layer(
hidden=self.summary, labels=label)
hidden=summary, labels=label)
self.add_loss(tf.keras.backend.mean(per_example_loss))
return new_mems, logits
......@@ -1087,7 +1087,12 @@ class Summarization(tf.keras.layers.Layer):
def call(self, inputs):
"""Implements call() for the layer."""
summary = inputs[-1]
if self.summary_type == 'last':
summary = inputs[-1]
elif self.summary_type == 'first':
summary = inputs[0]
else:
raise ValueError('Invalid summary type provided: %s' % self.summary_type)
summary = self.proj_layer(summary)
summary = self.dropout_layer(summary)
return summary
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment