Commit b5d73976 authored by erenup

Revert "fixing for roberta tokenizer decoding"

This reverts commit 22e7c4ed.
parent 22e7c4ed
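
For context on what is being reverted: RoBERTa uses a byte-level BPE that folds a leading space into the token itself, so an isolated word tokenizes differently from the same word preceded by a space mid-sentence. The reverted commit threaded an add_prefix_space flag through feature conversion to compensate for this. A minimal sketch of the difference, assuming the contemporary transformers RobertaTokenizer whose tokenize() forwarded add_prefix_space (the same call signature the reverted code below uses):

    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    # Byte-level BPE marks a leading space with 'G' inside the token, so an
    # isolated word and a mid-sentence word can yield different pieces:
    print(tokenizer.tokenize("Japanese"))                         # word-initial pieces
    print(tokenizer.tokenize("Japanese", add_prefix_space=True))  # e.g. ['GJapanese']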
@@ -263,7 +263,7 @@ def evaluate(args, model, tokenizer, prefix=""):
         write_predictions(examples, features, all_results, args.n_best_size,
                         args.max_answer_length, args.do_lower_case, output_prediction_file,
                         output_nbest_file, output_null_log_odds_file, args.verbose_logging,
-                        args.version_2_with_negative, args.null_score_diff_threshold, tokenizer, args.model_type)
+                        args.version_2_with_negative, args.null_score_diff_threshold)

     # Evaluate with the official SQuAD script
     evaluate_options = EVAL_OPTS(data_file=args.predict_file,
@@ -296,7 +296,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
                                                 max_seq_length=args.max_seq_length,
                                                 doc_stride=args.doc_stride,
                                                 max_query_length=args.max_query_length,
-                                                is_training=not evaluate, add_prefix_space=True if args.model_type == 'roberta' else False)
+                                                is_training=not evaluate)
         if args.local_rank in [-1, 0]:
             logger.info("Saving features into cached file %s", cached_features_file)
             torch.save(features, cached_features_file)
@@ -25,7 +25,6 @@ import collections
 from io import open

 from transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
-from transformers.tokenization_roberta import RobertaTokenizer

 # Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
 from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
@@ -193,7 +192,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                  cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
                                  sequence_a_segment_id=0, sequence_b_segment_id=1,
                                  cls_token_segment_id=0, pad_token_segment_id=0,
-                                 mask_padding_with_zero=True, add_prefix_space=False):
+                                 mask_padding_with_zero=True):
     """Loads a data file into a list of `InputBatch`s."""

     unique_id = 1000000000
@@ -206,7 +205,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         # if example_index % 100 == 0:
         #     logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)

-        query_tokens = tokenizer.tokenize(example.question_text, add_prefix_space=add_prefix_space)
+        query_tokens = tokenizer.tokenize(example.question_text)
+
         if len(query_tokens) > max_query_length:
             query_tokens = query_tokens[0:max_query_length]
@@ -216,7 +216,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         all_doc_tokens = []
         for (i, token) in enumerate(example.doc_tokens):
             orig_to_tok_index.append(len(all_doc_tokens))
-            sub_tokens = tokenizer.tokenize(token, add_prefix_space=add_prefix_space)
+            sub_tokens = tokenizer.tokenize(token)
             for sub_token in sub_tokens:
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)
@@ -234,7 +234,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_end_position = len(all_doc_tokens) - 1
             (tok_start_position, tok_end_position) = _improve_answer_span(
                 all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
-                example.orig_answer_text, add_prefix_space)
+                example.orig_answer_text)

         # The -3 accounts for [CLS], [SEP] and [SEP]
         max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
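
As a concrete instance of the arithmetic in the comment above: each feature is packed as [CLS] query [SEP] document-span [SEP], so three special tokens are subtracted once. With an illustrative max_seq_length of 384 and a 20-token question (values not asserted by the diff):

    max_seq_length = 384        # illustrative
    len_query_tokens = 20       # illustrative
    # [CLS] + query + [SEP] + doc span + [SEP] -> three special tokens
    max_tokens_for_doc = max_seq_length - len_query_tokens - 3
    print(max_tokens_for_doc)   # 361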
@@ -398,7 +398,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,

 def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
-                         orig_answer_text, add_prefix_space):
+                         orig_answer_text):
     """Returns tokenized answer spans that better match the annotated answer."""

     # The SQuAD annotations are character based. We first project them to
@@ -423,7 +423,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
     # the word "Japanese". Since our WordPiece tokenizer does not split
     # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
     # in SQuAD, but does happen.
-    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text, add_prefix_space=add_prefix_space))
+    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
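
The hunk above shows only fragments of _improve_answer_span. For readability, here is a sketch of the whole routine as it reads after the revert, reconstructed from the visible lines plus the standard BERT SQuAD utility it derives from; the lines outside the diff context are an assumption:

    def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                             orig_answer_text):
        """Returns tokenized answer spans that better match the annotated answer."""
        # Re-tokenize the annotated answer, then scan every sub-span of the
        # candidate window for an exact match against it.
        tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

        for new_start in range(input_start, input_end + 1):
            for new_end in range(input_end, new_start - 1, -1):
                text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
                if text_span == tok_answer_text:
                    return (new_start, new_end)

        # No tighter match found: keep the original span.
        return (input_start, input_end)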
@@ -477,7 +477,7 @@ RawResult = collections.namedtuple("RawResult",
 def write_predictions(all_examples, all_features, all_results, n_best_size,
                       max_answer_length, do_lower_case, output_prediction_file,
                       output_nbest_file, output_null_log_odds_file, verbose_logging,
-                      version_2_with_negative, null_score_diff_threshold, tokenizer, mode_type='bert'):
+                      version_2_with_negative, null_score_diff_threshold):
     """Write final predictions to the json file and log-odds of null if needed."""
     logger.info("Writing predictions to: %s" % (output_prediction_file))
     logger.info("Writing nbest to: %s" % (output_nbest_file))
@@ -576,13 +576,6 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                 tok_text = " ".join(tok_tokens)

                 # De-tokenize WordPieces that have been split off.
-                if mode_type == 'roberta':
-                    tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
-                    tok_text = tok_text.replace("##", "")
-                    tok_text = " ".join(tok_text.strip().split())
-                    orig_text = " ".join(orig_tokens)
-                    final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging, None)
-                else:
-                    tok_text = tok_text.replace(" ##", "")
-                    tok_text = tok_text.replace("##", "")
+                tok_text = tok_text.replace(" ##", "")
+                tok_text = tok_text.replace("##", "")