Commit 22e7c4ed authored by erenup

Fix RoBERTa tokenizer encoding and decoding for SQuAD

Thread an add_prefix_space flag through convert_examples_to_features and
_improve_answer_span so that questions, documents, and answers are tokenized
consistently for RoBERTa, and decode predicted token spans with
tokenizer.convert_tokens_to_string instead of the WordPiece "##" merge.

parent ebb32261
@@ -263,7 +263,7 @@ def evaluate(args, model, tokenizer, prefix=""):
         write_predictions(examples, features, all_results, args.n_best_size,
                         args.max_answer_length, args.do_lower_case, output_prediction_file,
                         output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                        args.version_2_with_negative, args.null_score_diff_threshold)
+                        args.version_2_with_negative, args.null_score_diff_threshold, tokenizer, args.model_type)
     # Evaluate with the official SQuAD script
     evaluate_options = EVAL_OPTS(data_file=args.predict_file,
@@ -296,7 +296,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
                                                 max_seq_length=args.max_seq_length,
                                                 doc_stride=args.doc_stride,
                                                 max_query_length=args.max_query_length,
-                                                is_training=not evaluate)
+                                                is_training=not evaluate, add_prefix_space=True if args.model_type == 'roberta' else False)
     if args.local_rank in [-1, 0]:
         logger.info("Saving features into cached file %s", cached_features_file)
         torch.save(features, cached_features_file)
...
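Why add_prefix_space matters for RoBERTa: it reuses GPT-2's byte-level BPE,
which encodes a word's leading space into the token itself, so a word
tokenized in isolation can split differently than the same word inside
running text. A minimal sketch of the difference, assuming the
transformers-2.x-era RobertaTokenizer this commit targets (the printed
token splits are illustrative, not guaranteed):

    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    # Tokenized in isolation there is no leading space, so the word may
    # split into different BPE pieces than it would mid-sentence.
    print(tokenizer.tokenize('Japanese'))                         # e.g. ['Jap', 'anese']

    # With add_prefix_space=True the word is tokenized as if it occurred
    # mid-sentence; 'Ġ' is the byte-encoded leading space.
    print(tokenizer.tokenize('Japanese', add_prefix_space=True))  # e.g. ['ĠJapanese']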
@@ -25,6 +25,7 @@ import collections
 from io import open
 from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
+from transformers.tokenization_roberta import RobertaTokenizer
 # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
 from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
@@ -192,7 +193,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=0, pad_token_segment_id=0,
-                                 mask_padding_with_zero=True):
+                                 mask_padding_with_zero=True, add_prefix_space=False):
     """Loads a data file into a list of `InputBatch`s."""
     unique_id = 1000000000
@@ -205,8 +206,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         # if example_index % 100 == 0:
         #     logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)
-
-        query_tokens = tokenizer.tokenize(example.question_text)
+        query_tokens = tokenizer.tokenize(example.question_text, add_prefix_space=add_prefix_space)
         if len(query_tokens) > max_query_length:
             query_tokens = query_tokens[0:max_query_length]
@@ -216,7 +216,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         all_doc_tokens = []
         for (i, token) in enumerate(example.doc_tokens):
             orig_to_tok_index.append(len(all_doc_tokens))
-            sub_tokens = tokenizer.tokenize(token)
+            sub_tokens = tokenizer.tokenize(token, add_prefix_space=add_prefix_space)
             for sub_token in sub_tokens:
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)
@@ -234,7 +234,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_end_position = len(all_doc_tokens) - 1
             (tok_start_position, tok_end_position) = _improve_answer_span(
                 all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
-                example.orig_answer_text)
+                example.orig_answer_text, add_prefix_space)
         # The -3 accounts for [CLS], [SEP] and [SEP]
         max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
@@ -398,7 +398,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
 def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
-                         orig_answer_text):
+                         orig_answer_text, add_prefix_space):
     """Returns tokenized answer spans that better match the annotated answer."""
     # The SQuAD annotations are character based. We first project them to
@@ -423,7 +423,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
     # the word "Japanese". Since our WordPiece tokenizer does not split
     # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
     # in SQuAD, but does happen.
-    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
+    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, add_prefix_space=add_prefix_space))
     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
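The flag is also threaded into _improve_answer_span because that function
re-tokenizes the raw answer text and compares the joined result against
sub-tokens already produced from the document; if only one side used the
prefix space, the strings could never be equal and the span would never be
tightened. A hypothetical illustration (token values are made up):

    # The document sub-tokens above were produced with add_prefix_space set,
    # so the answer must be tokenized the same way to stand a chance of matching.
    doc_span_text = ' '.join(['ĠJapanese'])         # 'ĠJapanese'
    answer_no_flag = ' '.join(['Jap', 'anese'])     # 'Jap anese'  -> never matches
    answer_with_flag = ' '.join(['ĠJapanese'])      # 'ĠJapanese' -> matches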
@@ -477,7 +477,7 @@ RawResult = collections.namedtuple("RawResult",
 def write_predictions(all_examples, all_features, all_results, n_best_size,
                       max_answer_length, do_lower_case, output_prediction_file,
                       output_nbest_file, output_null_log_odds_file, verbose_logging,
-                      version_2_with_negative, null_score_diff_threshold):
+                      version_2_with_negative, null_score_diff_threshold, tokenizer, model_type='bert'):
     """Write final predictions to the json file and log-odds of null if needed."""
     logger.info("Writing predictions to: %s" % (output_prediction_file))
     logger.info("Writing nbest to: %s" % (output_nbest_file))
@@ -576,15 +576,22 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                 tok_text = " ".join(tok_tokens)
 
                 # De-tokenize WordPieces that have been split off.
-                tok_text = tok_text.replace(" ##", "")
-                tok_text = tok_text.replace("##", "")
-
-                # Clean whitespace
-                tok_text = tok_text.strip()
-                tok_text = " ".join(tok_text.split())
-                orig_text = " ".join(orig_tokens)
-
-                final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
+                if model_type == 'roberta':
+                    tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+                    tok_text = tok_text.replace("##", "")
+                    tok_text = " ".join(tok_text.strip().split())
+                    orig_text = " ".join(orig_tokens)
+                    final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
+                else:
+                    tok_text = tok_text.replace(" ##", "")
+                    tok_text = tok_text.replace("##", "")
+
+                    # Clean whitespace
+                    tok_text = tok_text.strip()
+                    tok_text = " ".join(tok_text.split())
+                    orig_text = " ".join(orig_tokens)
+
+                    final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
 
                 if final_text in seen_predictions:
                     continue
...
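On the decoding side, the old code assumed BERT-style WordPiece, where
sub-word continuations carry a '##' marker that can be stripped with plain
string replacement. RoBERTa's byte-level BPE has no '##' convention and
instead carries byte-encoded characters such as 'Ġ', so the new branch
delegates to the tokenizer's own convert_tokens_to_string. A rough sketch
of the two paths (token lists are illustrative):

    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    # BERT/WordPiece path: merge sub-words by stripping '##' markers.
    bert_tokens = ['Jap', '##anese', 'cuisine']
    text = ' '.join(bert_tokens).replace(' ##', '').replace('##', '')
    # -> 'Japanese cuisine'

    # RoBERTa path: the naive '##' merge would leave 'Ġ' bytes and wrong
    # spacing, so let the tokenizer byte-decode its own pieces.
    roberta_tokens = ['ĠJapanese', 'Ġcuisine']
    text = tokenizer.convert_tokens_to_string(roberta_tokens).strip()
    # -> 'Japanese cuisine'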