Commit 9af479b3 authored by thomwolf

conversion run_squad ok

parent 8e81e5e6
@@ -27,7 +27,6 @@ import modeling
 import optimization
 import tokenization
 import six
-import tensorflow as tf
 import argparse
 from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
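Note: the converted code below logs through a module-level `logger` in place of `tf.logging`, but the diff context does not show its definition. A minimal sketch of the stdlib logging setup such calls assume (the format string is an assumption, not taken from this commit):

    import logging

    # Assumed setup; only logger.info/logger.warning usage is visible in the diff.
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    logger = logging.getLogger(__name__)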
@@ -177,7 +176,7 @@ class InputFeatures(object):
 def read_squad_examples(input_file, is_training):
   """Read a SQuAD json file into a list of SquadExample."""
-  with tf.gfile.Open(input_file, "r") as reader:
+  with open(input_file, "r") as reader:
     input_data = json.load(reader)["data"]
 
   def is_whitespace(c):
@@ -229,7 +228,7 @@ def read_squad_examples(input_file, is_training):
         cleaned_answer_text = " ".join(
             tokenization.whitespace_tokenize(orig_answer_text))
         if actual_text.find(cleaned_answer_text) == -1:
-          tf.logging.warning("Could not find answer: '%s' vs. '%s'",
+          logger.warning("Could not find answer: '%s' vs. '%s'",
                              actual_text, cleaned_answer_text)
           continue
@@ -356,27 +355,27 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         end_position = tok_end_position - doc_start + doc_offset
 
       if example_index < 20:
-        tf.logging.info("*** Example ***")
-        tf.logging.info("unique_id: %s" % (unique_id))
-        tf.logging.info("example_index: %s" % (example_index))
-        tf.logging.info("doc_span_index: %s" % (doc_span_index))
-        tf.logging.info("tokens: %s" % " ".join(
+        logger.info("*** Example ***")
+        logger.info("unique_id: %s" % (unique_id))
+        logger.info("example_index: %s" % (example_index))
+        logger.info("doc_span_index: %s" % (doc_span_index))
+        logger.info("tokens: %s" % " ".join(
             [tokenization.printable_text(x) for x in tokens]))
-        tf.logging.info("token_to_orig_map: %s" % " ".join(
+        logger.info("token_to_orig_map: %s" % " ".join(
             ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)]))
-        tf.logging.info("token_is_max_context: %s" % " ".join([
+        logger.info("token_is_max_context: %s" % " ".join([
             "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
         ]))
-        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-        tf.logging.info(
+        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
+        logger.info(
             "input_mask: %s" % " ".join([str(x) for x in input_mask]))
-        tf.logging.info(
+        logger.info(
             "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
         if is_training:
           answer_text = " ".join(tokens[start_position:(end_position + 1)])
-          tf.logging.info("start_position: %d" % (start_position))
-          tf.logging.info("end_position: %d" % (end_position))
-          tf.logging.info(
+          logger.info("start_position: %d" % (start_position))
+          logger.info("end_position: %d" % (end_position))
+          logger.info(
               "answer: %s" % (tokenization.printable_text(answer_text)))
 
       features.append(
@@ -471,207 +470,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
   return cur_span_index == best_span_index
 
-def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
-                 use_one_hot_embeddings):
-  """Creates a classification model."""
-  model = modeling.BertModel(
-      config=bert_config,
-      is_training=is_training,
-      input_ids=input_ids,
-      input_mask=input_mask,
-      token_type_ids=segment_ids,
-      use_one_hot_embeddings=use_one_hot_embeddings)
-
-  final_hidden = model.get_sequence_output()
-
-  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
-  batch_size = final_hidden_shape[0]
-  seq_length = final_hidden_shape[1]
-  hidden_size = final_hidden_shape[2]
-
-  output_weights = tf.get_variable(
-      "cls/squad/output_weights", [2, hidden_size],
-      initializer=tf.truncated_normal_initializer(stddev=0.02))
-
-  output_bias = tf.get_variable(
-      "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())
-
-  final_hidden_matrix = tf.reshape(final_hidden,
-                                   [batch_size * seq_length, hidden_size])
-  logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
-  logits = tf.nn.bias_add(logits, output_bias)
-
-  logits = tf.reshape(logits, [batch_size, seq_length, 2])
-  logits = tf.transpose(logits, [2, 0, 1])
-
-  unstacked_logits = tf.unstack(logits, axis=0)
-
-  (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])
-
-  return (start_logits, end_logits)
-
-def model_fn_builder(bert_config, init_checkpoint, learning_rate,
-                     num_train_steps, num_warmup_steps, use_tpu,
-                     use_one_hot_embeddings):
-  """Returns `model_fn` closure for TPUEstimator."""
-
-  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
-    """The `model_fn` for TPUEstimator."""
-
-    tf.logging.info("*** Features ***")
-    for name in sorted(features.keys()):
-      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))
-
-    unique_ids = features["unique_ids"]
-    input_ids = features["input_ids"]
-    input_mask = features["input_mask"]
-    segment_ids = features["segment_ids"]
-
-    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
-
-    (start_logits, end_logits) = create_model(
-        bert_config=bert_config,
-        is_training=is_training,
-        input_ids=input_ids,
-        input_mask=input_mask,
-        segment_ids=segment_ids,
-        use_one_hot_embeddings=use_one_hot_embeddings)
-
-    tvars = tf.trainable_variables()
-
-    initialized_variable_names = {}
-    scaffold_fn = None
-    if init_checkpoint:
-      (assignment_map,
-       initialized_variable_names) = modeling.get_assigment_map_from_checkpoint(
-           tvars, init_checkpoint)
-      if use_tpu:
-
-        def tpu_scaffold():
-          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-          return tf.train.Scaffold()
-
-        scaffold_fn = tpu_scaffold
-      else:
-        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-
-    tf.logging.info("**** Trainable Variables ****")
-    for var in tvars:
-      init_string = ""
-      if var.name in initialized_variable_names:
-        init_string = ", *INIT_FROM_CKPT*"
-      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
-                      init_string)
-
-    output_spec = None
-    if mode == tf.estimator.ModeKeys.TRAIN:
-      seq_length = modeling.get_shape_list(input_ids)[1]
-
-      def compute_loss(logits, positions):
-        one_hot_positions = tf.one_hot(
-            positions, depth=seq_length, dtype=tf.float32)
-        log_probs = tf.nn.log_softmax(logits, axis=-1)
-        loss = -tf.reduce_mean(
-            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
-        return loss
-
-      start_positions = features["start_positions"]
-      end_positions = features["end_positions"]
-
-      start_loss = compute_loss(start_logits, start_positions)
-      end_loss = compute_loss(end_logits, end_positions)
-
-      total_loss = (start_loss + end_loss) / 2.0
-
-      train_op = optimization.create_optimizer(
-          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
-
-      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-          mode=mode,
-          loss=total_loss,
-          train_op=train_op,
-          scaffold_fn=scaffold_fn)
-    elif mode == tf.estimator.ModeKeys.PREDICT:
-      predictions = {
-          "unique_ids": unique_ids,
-          "start_logits": start_logits,
-          "end_logits": end_logits,
-      }
-      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-          mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
-    else:
-      raise ValueError(
-          "Only TRAIN and PREDICT modes are supported: %s" % (mode))
-
-    return output_spec
-
-  return model_fn
-
-def input_fn_builder(features, seq_length, is_training, drop_remainder):
-  """Creates an `input_fn` closure to be passed to TPUEstimator."""
-
-  all_unique_ids = []
-  all_input_ids = []
-  all_input_mask = []
-  all_segment_ids = []
-  all_start_positions = []
-  all_end_positions = []
-
-  for feature in features:
-    all_unique_ids.append(feature.unique_id)
-    all_input_ids.append(feature.input_ids)
-    all_input_mask.append(feature.input_mask)
-    all_segment_ids.append(feature.segment_ids)
-    if is_training:
-      all_start_positions.append(feature.start_position)
-      all_end_positions.append(feature.end_position)
-
-  def input_fn(params):
-    """The actual input function."""
-    batch_size = params["batch_size"]
-
-    num_examples = len(features)
-
-    # This is for demo purposes and does NOT scale to large data sets. We do
-    # not use Dataset.from_generator() because that uses tf.py_func which is
-    # not TPU compatible. The right way to load data is with TFRecordReader.
-    feature_map = {
-        "unique_ids":
-            tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
-        "input_ids":
-            tf.constant(
-                all_input_ids, shape=[num_examples, seq_length],
-                dtype=tf.int32),
-        "input_mask":
-            tf.constant(
-                all_input_mask,
-                shape=[num_examples, seq_length],
-                dtype=tf.int32),
-        "segment_ids":
-            tf.constant(
-                all_segment_ids,
-                shape=[num_examples, seq_length],
-                dtype=tf.int32),
-    }
-
-    if is_training:
-      feature_map["start_positions"] = tf.constant(
-          all_start_positions, shape=[num_examples], dtype=tf.int32)
-      feature_map["end_positions"] = tf.constant(
-          all_end_positions, shape=[num_examples], dtype=tf.int32)
-
-    d = tf.data.Dataset.from_tensor_slices(feature_map)
-    if is_training:
-      d = d.repeat()
-      d = d.shuffle(buffer_size=100)
-    d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
-    return d
-
-  return input_fn
-
 RawResult = collections.namedtuple("RawResult",
                                    ["unique_id", "start_logits", "end_logits"])
@@ -681,8 +479,8 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                       max_answer_length, do_lower_case, output_prediction_file,
                       output_nbest_file):
   """Write final predictions to the json file."""
-  tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
-  tf.logging.info("Writing nbest to: %s" % (output_nbest_file))
+  logger.info("Writing predictions to: %s" % (output_prediction_file))
+  logger.info("Writing nbest to: %s" % (output_nbest_file))
 
   example_index_to_features = collections.defaultdict(list)
   for feature in all_features:
@@ -804,10 +602,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
     all_predictions[example.qas_id] = nbest_json[0]["text"]
     all_nbest_json[example.qas_id] = nbest_json
 
-  with tf.gfile.GFile(output_prediction_file, "w") as writer:
+  with open(output_prediction_file, "w") as writer:
     writer.write(json.dumps(all_predictions, indent=4) + "\n")
 
-  with tf.gfile.GFile(output_nbest_file, "w") as writer:
+  with open(output_nbest_file, "w") as writer:
     writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
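Note: write_predictions() consumes the RawResult tuples defined above. A sketch of how a PyTorch prediction loop might fill them; `eval_dataloader` and `eval_features` are hypothetical names, and SequentialSampler (imported at the top of the file) is assumed to keep batches aligned with feature order:

    # Hypothetical eval loop collecting per-feature logits into RawResult.
    model.eval()
    all_results = []
    for input_ids, input_mask, segment_ids, example_indices in eval_dataloader:
        with torch.no_grad():
            batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
        for i, example_index in enumerate(example_indices):
            feature = eval_features[example_index.item()]
            all_results.append(RawResult(unique_id=int(feature.unique_id),
                                         start_logits=batch_start_logits[i].tolist(),
                                         end_logits=batch_end_logits[i].tolist()))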
@@ -861,7 +659,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
   start_position = tok_text.find(pred_text)
   if start_position == -1:
     if args.verbose_logging:
-      tf.logging.info(
+      logger.info(
           "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
     return orig_text
   end_position = start_position + len(pred_text) - 1
@@ -871,7 +669,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
   if len(orig_ns_text) != len(tok_ns_text):
     if args.verbose_logging:
-      tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
+      logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
                       orig_ns_text, tok_ns_text)
     return orig_text
@@ -889,7 +687,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
   if orig_start_position is None:
     if args.verbose_logging:
-      tf.logging.info("Couldn't map start position")
+      logger.info("Couldn't map start position")
     return orig_text
 
   orig_end_position = None
@@ -900,7 +698,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
   if orig_end_position is None:
     if args.verbose_logging:
-      tf.logging.info("Couldn't map end position")
+      logger.info("Couldn't map end position")
     return orig_text
 
   output_text = orig_text[orig_start_position:(orig_end_position + 1)]
...