Commit 0dc9a20a authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Fix XLNet Classifier.

PiperOrigin-RevId: 334642892
parent 79354e14
...@@ -153,7 +153,7 @@ def convert_single_example(example_index, example, label_list, max_seq_length, ...@@ -153,7 +153,7 @@ def convert_single_example(example_index, example, label_list, max_seq_length,
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids])) logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask])) logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids])) logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("label: %d (id = %d)", example.label, label_id) logging.info("label: %s (id = %d)", example.label, label_id)
feature = InputFeatures( feature = InputFeatures(
input_ids=input_ids, input_ids=input_ids,
......
...@@ -155,7 +155,7 @@ def main(unused_argv): ...@@ -155,7 +155,7 @@ def main(unused_argv):
adam_epsilon=FLAGS.adam_epsilon) adam_epsilon=FLAGS.adam_epsilon)
model_config = xlnet_config.XLNetConfig(FLAGS) model_config = xlnet_config.XLNetConfig(FLAGS)
run_config = xlnet_config.create_run_config(True, False, FLAGS) run_config = xlnet_config.create_run_config(True, False, FLAGS)
model_fn = functools.partial(modeling.classification_model, model_config, model_fn = functools.partial(get_classificationxlnet_model, model_config,
run_config, FLAGS.n_class, FLAGS.summary_type) run_config, FLAGS.n_class, FLAGS.summary_type)
input_meta_data = {} input_meta_data = {}
input_meta_data["d_model"] = FLAGS.d_model input_meta_data["d_model"] = FLAGS.d_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment