"torchvision/csrc/vscode:/vscode.git/clone" did not exist on "52b8685bde554501604471337a11578fdf026027"
Commit 0dc9a20a authored by Allen Wang's avatar Allen Wang Committed by A. Unique TensorFlower
Browse files

Fix XLNet Classifier.

PiperOrigin-RevId: 334642892
parent 79354e14
......@@ -153,7 +153,7 @@ def convert_single_example(example_index, example, label_list, max_seq_length,
logging.info("input_ids: %s", " ".join([str(x) for x in input_ids]))
logging.info("input_mask: %s", " ".join([str(x) for x in input_mask]))
logging.info("segment_ids: %s", " ".join([str(x) for x in segment_ids]))
logging.info("label: %d (id = %d)", example.label, label_id)
logging.info("label: %s (id = %d)", example.label, label_id)
feature = InputFeatures(
input_ids=input_ids,
......
......@@ -155,7 +155,7 @@ def main(unused_argv):
adam_epsilon=FLAGS.adam_epsilon)
model_config = xlnet_config.XLNetConfig(FLAGS)
run_config = xlnet_config.create_run_config(True, False, FLAGS)
model_fn = functools.partial(modeling.classification_model, model_config,
model_fn = functools.partial(get_classificationxlnet_model, model_config,
run_config, FLAGS.n_class, FLAGS.summary_type)
input_meta_data = {}
input_meta_data["d_model"] = FLAGS.d_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment