"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6f877d9daf36788bad4fd228930939fed6ab12bd"
Commit dc4e9e5c authored by Lysandre

DataParallel for SQuAD + fix XLM

parent e6cff60b
@@ -299,10 +299,14 @@ def evaluate(args, model, tokenizer, prefix=""):
     # XLNet and XLM use a more complex post-processing procedure
     if args.model_type in ['xlnet', 'xlm']:
+        start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
+        end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top
+
         predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
                                                     args.max_answer_length, output_prediction_file,
                                                     output_nbest_file, output_null_log_odds_file,
-                                                    model.config.start_n_top, model.config.end_n_top,
+                                                    start_n_top, end_n_top,
                                                     args.version_2_with_negative, tokenizer, args.verbose_logging)
     else:
         predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
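For context: when evaluating on multiple GPUs, `run_squad.py` wraps the model in `torch.nn.DataParallel`, which does not forward arbitrary attribute lookups to the wrapped module, so the config can only be reached via `model.module.config`. Below is a minimal sketch of why the `hasattr(model, "config")` fallback is needed; `TinyModel` is an illustrative stand-in, not the actual SQuAD model.

```python
import torch

# Illustrative stand-in for a transformers model: like the XLNet/XLM models,
# it carries a `config` attribute directly on the module.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.config = type("Config", (), {"start_n_top": 5, "end_n_top": 5})()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()
wrapped = torch.nn.DataParallel(model)  # what multi-GPU evaluation does to the model

# DataParallel does not delegate custom attributes to the inner module:
# `wrapped.config` raises AttributeError, while `wrapped.module.config` works.
# The hasattr() check from the diff above handles both the plain and the
# wrapped case with the same line of code.
for m in (model, wrapped):
    start_n_top = m.config.start_n_top if hasattr(m, "config") else m.module.config.start_n_top
    print(type(m).__name__, start_n_top)
```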
@@ -695,7 +695,12 @@ def compute_predictions_log_probs(
             tok_text = " ".join(tok_text.split())
             orig_text = " ".join(orig_tokens)

-            final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
+            if hasattr(tokenizer, "do_lower_case"):
+                do_lower_case = tokenizer.do_lower_case
+            else:
+                do_lower_case = tokenizer.do_lowercase_and_remove_accent
+
+            final_text = get_final_text(tok_text, orig_text, do_lower_case,
                                         verbose_logging)
             if final_text in seen_predictions:
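The reason for this `hasattr` fallback: most tokenizers in the library expose a `do_lower_case` flag, whereas `XLMTokenizer` records its casing behaviour as `do_lowercase_and_remove_accent`, so the post-processing has to probe for whichever attribute exists. A small sketch of that pattern, using hypothetical stand-in classes rather than the real tokenizers:

```python
# Hypothetical stand-ins that mimic the attribute mismatch between
# BERT-style tokenizers and XLMTokenizer.
class BertLikeTokenizer:
    def __init__(self, do_lower_case=True):
        self.do_lower_case = do_lower_case

class XLMLikeTokenizer:
    def __init__(self, do_lowercase_and_remove_accent=True):
        self.do_lowercase_and_remove_accent = do_lowercase_and_remove_accent

def resolve_lower_case(tokenizer):
    """Mirror the fallback added in this commit."""
    if hasattr(tokenizer, "do_lower_case"):
        return tokenizer.do_lower_case
    return tokenizer.do_lowercase_and_remove_accent

assert resolve_lower_case(BertLikeTokenizer()) is True
assert resolve_lower_case(XLMLikeTokenizer(False)) is False
```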
@@ -549,6 +549,10 @@ class XLMTokenizer(PreTrainedTokenizer):
                          additional_special_tokens=additional_special_tokens,
                          **kwargs)
+        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
+        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens
+
         # cache of sm.MosesPunctNormalizer instance
         self.cache_moses_punct_normalizer = dict()
         # cache of sm.MosesTokenizer instance
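These two attributes reserve room for the special tokens that XLM adds when building model inputs: a single sequence is wrapped as `<s> A </s>` (2 special tokens) and a pair as `<s> A </s> B </s>` (3 special tokens). A sketch of that arithmetic, assuming the XLM-style wrapping just described and a 512-token model limit (the value used by the pretrained XLM checkpoints):

```python
# Sketch, assuming XLM-style wrapping: <s> A </s> for a single sequence
# and <s> A </s> B </s> for a pair, with a 512-token model limit.
MAX_LEN = 512

max_len_single_sentence = MAX_LEN - 2  # room for <s> ... </s>
max_len_sentences_pair = MAX_LEN - 3   # room for <s> ... </s> ... </s>

def fits(num_tokens_a, num_tokens_b=None):
    """Return True if the token counts still fit once special tokens are added."""
    if num_tokens_b is None:
        return num_tokens_a <= max_len_single_sentence
    return num_tokens_a + num_tokens_b <= max_len_sentences_pair

assert fits(510)           # 510 + 2 special tokens == 512
assert not fits(511)       # would overflow once wrapped
assert fits(300, 209)      # 300 + 209 + 3 == 512
assert not fits(300, 210)
```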