"tests/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "2f88c1249bf13f6e979dda3202f883936b382d95"
Commit 444e8c79 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Update XLNet classifier:

- Support first/last summary type
- Support bert format input processing

PiperOrigin-RevId: 276005310
parent 9616954e
...@@ -13,18 +13,12 @@ ...@@ -13,18 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Utilities for pre-processing classification data.""" """Utilities for pre-processing classification data."""
from absl import flags
from absl import logging from absl import logging
from official.nlp.xlnet import data_utils from official.nlp.xlnet import data_utils
FLAGS = flags.FLAGS
SEG_ID_A = 0 SEG_ID_A = 0
SEG_ID_B = 1 SEG_ID_B = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4
class PaddingInputExample(object): class PaddingInputExample(object):
...@@ -72,8 +66,8 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): ...@@ -72,8 +66,8 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
tokens_b.pop() tokens_b.pop()
def convert_single_example(ex_index, example, label_list, max_seq_length, def convert_single_example(example_index, example, label_list, max_seq_length,
tokenize_fn): tokenize_fn, use_bert_format):
"""Converts a single `InputExample` into a single `InputFeatures`.""" """Converts a single `InputExample` into a single `InputFeatures`."""
if isinstance(example, PaddingInputExample): if isinstance(example, PaddingInputExample):
...@@ -119,8 +113,12 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -119,8 +113,12 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
tokens.append(data_utils.SEP_ID) tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_B) segment_ids.append(SEG_ID_B)
tokens.append(data_utils.CLS_ID) if use_bert_format:
segment_ids.append(SEG_ID_CLS) tokens.insert(0, data_utils.CLS_ID)
segment_ids.insert(0, data_utils.SEG_ID_CLS)
else:
tokens.append(data_utils.CLS_ID)
segment_ids.append(data_utils.SEG_ID_CLS)
input_ids = tokens input_ids = tokens
...@@ -131,9 +129,14 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -131,9 +129,14 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
# Zero-pad up to the sequence length. # Zero-pad up to the sequence length.
if len(input_ids) < max_seq_length: if len(input_ids) < max_seq_length:
delta_len = max_seq_length - len(input_ids) delta_len = max_seq_length - len(input_ids)
input_ids = [0] * delta_len + input_ids if use_bert_format:
input_mask = [1] * delta_len + input_mask input_ids = input_ids + [0] * delta_len
segment_ids = [SEG_ID_PAD] * delta_len + segment_ids input_mask = input_mask + [1] * delta_len
segment_ids = segment_ids + [data_utils.SEG_ID_PAD] * delta_len
else:
input_ids = [0] * delta_len + input_ids
input_mask = [1] * delta_len + input_mask
segment_ids = [data_utils.SEG_ID_PAD] * delta_len + segment_ids
assert len(input_ids) == max_seq_length assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length assert len(input_mask) == max_seq_length
...@@ -143,7 +146,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length, ...@@ -143,7 +146,7 @@ def convert_single_example(ex_index, example, label_list, max_seq_length,
label_id = label_map[example.label] label_id = label_map[example.label]
else: else:
label_id = example.label label_id = example.label
if ex_index < 5: if example_index < 5:
logging.info("*** Example ***") logging.info("*** Example ***")
logging.info("guid: %s", (example.guid)) logging.info("guid: %s", (example.guid))
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids])) logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
......
...@@ -55,6 +55,10 @@ flags.DEFINE_integer( ...@@ -55,6 +55,10 @@ flags.DEFINE_integer(
flags.DEFINE_bool("uncased", default=False, help="Use uncased.") flags.DEFINE_bool("uncased", default=False, help="Use uncased.")
flags.DEFINE_bool( flags.DEFINE_bool(
"is_regression", default=False, help="Whether it's a regression task.") "is_regression", default=False, help="Whether it's a regression task.")
flags.DEFINE_bool(
"use_bert_format",
default=False,
help="Whether to use BERT format to arrange input data.")
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -356,7 +360,8 @@ def file_based_convert_examples_to_features(examples, ...@@ -356,7 +360,8 @@ def file_based_convert_examples_to_features(examples,
feature = classifier_utils.convert_single_example(ex_index, example, feature = classifier_utils.convert_single_example(ex_index, example,
label_list, label_list,
max_seq_length, max_seq_length,
tokenize_fn) tokenize_fn,
FLAGS.use_bert_format)
def create_int_feature(values): def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
......
...@@ -36,13 +36,20 @@ from official.nlp.xlnet import training_utils ...@@ -36,13 +36,20 @@ from official.nlp.xlnet import training_utils
from official.utils.misc import tpu_lib from official.utils.misc import tpu_lib
flags.DEFINE_integer("n_class", default=2, help="Number of classes.") flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
flags.DEFINE_string(
"summary_type",
default="last",
help="Method used to summarize a sequence into a vector.")
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def get_classificationxlnet_model(model_config,
                                  run_config,
                                  n_class,
                                  summary_type="last"):
  """Builds an XLNet classification model.

  Args:
    model_config: `XLNetConfig` describing the transformer architecture.
    run_config: run configuration (dropout rates, initializer, etc.).
    n_class: int, number of output classes.
    summary_type: str, how to summarize the encoded sequence into a single
      vector ("last" or "first"). Defaults to "last", which preserves the
      pre-existing behavior for callers that do not pass it.

  Returns:
    A `modeling.ClassificationXLNetModel` instance.
  """
  # NOTE(review): `modeling` is a module-level import elsewhere in this
  # file (not visible in this chunk) — presumably official.nlp.xlnet.
  model = modeling.ClassificationXLNetModel(
      model_config, run_config, n_class, summary_type, name="model")
  return model
...@@ -65,6 +72,7 @@ def run_evaluation(strategy, ...@@ -65,6 +72,7 @@ def run_evaluation(strategy,
them when calculating the accuracy. For the reason that there will be them when calculating the accuracy. For the reason that there will be
dynamic-shape tensor, we first collect logits, labels and masks from TPU dynamic-shape tensor, we first collect logits, labels and masks from TPU
and calculate the accuracy via numpy locally. and calculate the accuracy via numpy locally.
Returns: Returns:
A float metric, accuracy. A float metric, accuracy.
""" """
...@@ -159,7 +167,7 @@ def main(unused_argv): ...@@ -159,7 +167,7 @@ def main(unused_argv):
model_config = xlnet_config.XLNetConfig(FLAGS) model_config = xlnet_config.XLNetConfig(FLAGS)
run_config = xlnet_config.create_run_config(True, False, FLAGS) run_config = xlnet_config.create_run_config(True, False, FLAGS)
model_fn = functools.partial(get_classificationxlnet_model, model_config, model_fn = functools.partial(get_classificationxlnet_model, model_config,
run_config, FLAGS.n_class) run_config, FLAGS.n_class, FLAGS.summary_type)
input_meta_data = {} input_meta_data = {}
input_meta_data["d_model"] = FLAGS.d_model input_meta_data["d_model"] = FLAGS.d_model
input_meta_data["mem_len"] = FLAGS.mem_len input_meta_data["mem_len"] = FLAGS.mem_len
......
...@@ -888,7 +888,7 @@ class ClassificationXLNetModel(tf.keras.Model): ...@@ -888,7 +888,7 @@ class ClassificationXLNetModel(tf.keras.Model):
""" """
def __init__(self, xlnet_config, run_config, n_class, **kwargs): def __init__(self, xlnet_config, run_config, n_class, summary_type, **kwargs):
super(ClassificationXLNetModel, self).__init__(**kwargs) super(ClassificationXLNetModel, self).__init__(**kwargs)
self.run_config = run_config self.run_config = run_config
self.initializer = _get_initializer(run_config) self.initializer = _get_initializer(run_config)
...@@ -924,7 +924,7 @@ class ClassificationXLNetModel(tf.keras.Model): ...@@ -924,7 +924,7 @@ class ClassificationXLNetModel(tf.keras.Model):
dropout_att=self.run_config.dropout_att, dropout_att=self.run_config.dropout_att,
initializer=self.initializer, initializer=self.initializer,
use_proj=True, use_proj=True,
summary_type='last', summary_type=summary_type,
name='sequence_summary') name='sequence_summary')
self.cl_loss_layer = ClassificationLossLayer( self.cl_loss_layer = ClassificationLossLayer(
...@@ -946,9 +946,9 @@ class ClassificationXLNetModel(tf.keras.Model): ...@@ -946,9 +946,9 @@ class ClassificationXLNetModel(tf.keras.Model):
self.transformerxl_model( self.transformerxl_model(
inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask, mems=mems)) inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask, mems=mems))
self.summary = self.summarization_layer(transformerxl_output) summary = self.summarization_layer(transformerxl_output)
per_example_loss, logits = self.cl_loss_layer( per_example_loss, logits = self.cl_loss_layer(
hidden=self.summary, labels=label) hidden=summary, labels=label)
self.add_loss(tf.keras.backend.mean(per_example_loss)) self.add_loss(tf.keras.backend.mean(per_example_loss))
return new_mems, logits return new_mems, logits
...@@ -1087,7 +1087,12 @@ class Summarization(tf.keras.layers.Layer): ...@@ -1087,7 +1087,12 @@ class Summarization(tf.keras.layers.Layer):
def call(self, inputs):
  """Implements call() for the layer.

  Pools the sequence of hidden states into a single summary vector:
  the state at the last position for summary_type "last", or the first
  position for "first"; any other value raises ValueError. The pooled
  vector is then passed through the projection and dropout layers.
  """
  if self.summary_type == 'last':
    summary = inputs[-1]
  elif self.summary_type == 'first':
    summary = inputs[0]
  else:
    raise ValueError('Invalid summary type provided: %s' % self.summary_type)
  summary = self.proj_layer(summary)
  summary = self.dropout_layer(summary)
  return summary
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment