Commit 15d8b126 authored by thomwolf

update tokenizer - update squad example for xlnet

parent 3b469cb4
@@ -242,7 +242,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
     # Load data features from cache or dataset file
     cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
         'dev' if evaluate else 'train',
-        list(filter(None, args.model_name.split('/'))).pop(),
+        list(filter(None, args.model_name_or_path.split('/'))).pop(),
         str(args.max_seq_length),
         str(task)))
     if os.path.exists(cached_features_file):
@@ -282,8 +282,10 @@ def main():
     ## Required parameters
     parser.add_argument("--data_dir", default=None, type=str, required=True,
                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--model_name", default=None, type=str, required=True,
-                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--task_name", default=None, type=str, required=True,
                         help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
     parser.add_argument("--output_dir", default=None, type=str, required=True,
@@ -400,15 +402,11 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
-    args.model_type = ""
-    for key in MODEL_CLASSES:
-        if key in args.model_name.lower():
-            args.model_type = key  # take the first match in model types
-            break
+    args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
 
     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
......
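The run_glue.py hunks above replace implicit model-type detection (substring-matching the checkpoint name) with an explicit --model_type flag that indexes the script's MODEL_CLASSES dispatch table. A minimal sketch of the pattern, assuming tuples along the lines of what the example scripts define (the exact table in this commit may differ):

    # Sketch of a (config, model, tokenizer) dispatch table keyed by model type.
    from pytorch_transformers import (BertConfig, BertForSequenceClassification, BertTokenizer,
                                      XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer,
                                      XLMConfig, XLMForSequenceClassification, XLMTokenizer)

    MODEL_CLASSES = {
        'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
        'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
        'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
    }

    # An explicit lookup fails loudly (KeyError) on an unknown type, instead of
    # silently taking "the first match" when a key is a substring of the name.
    config_class, model_class, tokenizer_class = MODEL_CLASSES['bert']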
@@ -213,7 +213,6 @@ def evaluate(args, model, tokenizer, prefix=""):
                 inputs.update({'cls_index': batch[4],
                                'p_mask': batch[5]})
             outputs = model(**inputs)
-            batch_start_logits, batch_end_logits = outputs[:2]
 
         for i, example_index in enumerate(example_indices):
             eval_feature = features[example_index.item()]
@@ -242,7 +241,8 @@ def evaluate(args, model, tokenizer, prefix=""):
         write_predictions_extended(examples, features, all_results, args.n_best_size,
                                    args.max_answer_length, output_prediction_file,
                                    output_nbest_file, output_null_log_odds_file, args.predict_file,
-                                   args.start_n_top, args.end_n_top, args.version_2_with_negative)
+                                   model.config.start_n_top, model.config.end_n_top,
+                                   args.version_2_with_negative, tokenizer, args.verbose_logging)
     else:
         write_predictions(examples, features, all_results, args.n_best_size,
                           args.max_answer_length, args.do_lower_case, output_prediction_file,
@@ -262,7 +262,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     input_file = args.predict_file if evaluate else args.train_file
     cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
         'dev' if evaluate else 'train',
-        list(filter(None, args.model_name.split('/'))).pop(),
+        list(filter(None, args.model_name_or_path.split('/'))).pop(),
         str(args.max_seq_length)))
     if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
         logger.info("Loading features from cached file %s", cached_features_file)
@@ -312,8 +312,10 @@ def main():
                         help="SQuAD json for training. E.g., train-v1.1.json")
     parser.add_argument("--predict_file", default=None, type=str, required=True,
                         help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
-    parser.add_argument("--model_name", default=None, type=str, required=True,
-                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model checkpoints and predictions will be written.")
@@ -438,15 +440,11 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
-    args.model_type = ""
-    for key in MODEL_CLASSES:
-        if key in args.model_name.lower():
-            args.model_type = key  # take the first match in model types
-            break
+    args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
 
     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
......
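run_squad.py gets the same split between model type and checkpoint. A hypothetical invocation with the new flags (data paths and output directory are placeholders):

    python run_squad.py \
        --model_type xlnet \
        --model_name_or_path xlnet-large-cased \
        --do_train \
        --do_eval \
        --train_file train-v1.1.json \
        --predict_file dev-v1.1.json \
        --output_dir /tmp/debug_squad/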
@@ -60,8 +60,9 @@ class ExamplesTests(unittest.TestCase):
                     "--warmup_steps=2",
                     "--overwrite_output_dir",
                     "--seed=42"]
-        model_name = "--model_name=bert-base-uncased"
-        with patch.object(sys, 'argv', testargs + [model_name]):
+        model_type, model_name = ("--model_type=bert",
+                                  "--model_name_or_path=bert-base-uncased")
+        with patch.object(sys, 'argv', testargs + [model_type, model_name]):
             result = run_glue.main()
             for value in result.values():
                 self.assertGreaterEqual(value, 0.75)
@@ -85,8 +86,9 @@ class ExamplesTests(unittest.TestCase):
                     "--per_gpu_eval_batch_size=1",
                     "--overwrite_output_dir",
                     "--seed=42"]
-        model_name = "--model_name=bert-base-uncased"
-        with patch.object(sys, 'argv', testargs + [model_name]):
+        model_type, model_name = ("--model_type=bert",
+                                  "--model_name_or_path=bert-base-uncased")
+        with patch.object(sys, 'argv', testargs + [model_type, model_name]):
             result = run_squad.main()
             self.assertGreaterEqual(result['f1'], 30)
             self.assertGreaterEqual(result['exact'], 30)
......
@@ -87,6 +87,7 @@ class InputFeatures(object):
                  segment_ids,
                  cls_index,
                  p_mask,
+                 paragraph_len,
                  start_position=None,
                  end_position=None,
                  is_impossible=None):
@@ -101,6 +102,7 @@ class InputFeatures(object):
         self.segment_ids = segment_ids
         self.cls_index = cls_index
         self.p_mask = p_mask
+        self.paragraph_len = paragraph_len
         self.start_position = start_position
         self.end_position = end_position
         self.is_impossible = is_impossible
@@ -292,6 +294,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
            tokens.append(all_doc_tokens[split_token_index])
            segment_ids.append(sequence_b_segment_id)
            p_mask.append(0)
+        paragraph_len = doc_span.length
 
         # SEP token
         tokens.append(sep_token)
@@ -385,6 +388,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                     segment_ids=segment_ids,
                     cls_index=cls_index,
                     p_mask=p_mask,
+                    paragraph_len=paragraph_len,
                     start_position=start_position,
                     end_position=end_position,
                     is_impossible=span_is_impossible))
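The new paragraph_len field records how many document tokens actually fall inside a given doc span; spans come from sliding a window over documents longer than the sequence budget. A sketch of that striding logic, assuming the _DocSpan namedtuple used by the BERT-style feature converter that utils_squad builds on:

    import collections

    # Assumed to mirror the doc-span helper in the SQuAD feature converter.
    DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

    def make_doc_spans(n_doc_tokens, max_tokens_for_doc, doc_stride):
        """Slide a window over the tokenized document: each span keeps up to
        max_tokens_for_doc tokens and the window advances by doc_stride."""
        doc_spans = []
        start_offset = 0
        while start_offset < n_doc_tokens:
            length = min(n_doc_tokens - start_offset, max_tokens_for_doc)
            doc_spans.append(DocSpan(start=start_offset, length=length))
            if start_offset + length == n_doc_tokens:
                break
            start_offset += min(length, doc_stride)
        return doc_spans

    print(make_doc_spans(10, 4, 2))
    # [DocSpan(start=0, length=4), DocSpan(start=2, length=4),
    #  DocSpan(start=4, length=4), DocSpan(start=6, length=4)]
    # paragraph_len for each resulting feature is just doc_span.length.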
@@ -673,8 +677,9 @@ RawResultExtended = collections.namedtuple("RawResultExtended",
 def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
                                max_answer_length, output_prediction_file,
                                output_nbest_file,
-                               output_null_log_odds_file, orig_data,
-                               start_n_top, end_n_top, version_2_with_negative):
+                               output_null_log_odds_file, orig_data_file,
+                               start_n_top, end_n_top, version_2_with_negative,
+                               tokenizer, verbose_logging):
     """ XLNet write prediction logic (more complex than Bert's).
         Write final predictions to the json file and log-odds of null if needed.
@@ -764,13 +769,30 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
                 break
             feature = features[pred.feature_index]
 
-            tok_start_to_orig_index = feature.tok_start_to_orig_index
-            tok_end_to_orig_index = feature.tok_end_to_orig_index
-            start_orig_pos = tok_start_to_orig_index[pred.start_index]
-            end_orig_pos = tok_end_to_orig_index[pred.end_index]
-            paragraph_text = example.paragraph_text
-            final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
+            # XLNet un-tokenizer
+            # Let's keep it simple for now and see if we need all this later.
+            #
+            # tok_start_to_orig_index = feature.tok_start_to_orig_index
+            # tok_end_to_orig_index = feature.tok_end_to_orig_index
+            # start_orig_pos = tok_start_to_orig_index[pred.start_index]
+            # end_orig_pos = tok_end_to_orig_index[pred.end_index]
+            # paragraph_text = example.paragraph_text
+            # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
+
+            # Previously used Bert untokenizer
+            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
+            orig_doc_start = feature.token_to_orig_map[pred.start_index]
+            orig_doc_end = feature.token_to_orig_map[pred.end_index]
+            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
+            tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+
+            # Clean whitespace
+            tok_text = tok_text.strip()
+            tok_text = " ".join(tok_text.split())
+            orig_text = " ".join(orig_tokens)
+
+            final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
+                                        verbose_logging)
 
             if final_text in seen_predictions:
                 continue
@@ -829,6 +851,9 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
         with open(output_null_log_odds_file, "w") as writer:
             writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
 
+    with open(orig_data_file, "r", encoding='utf-8') as reader:
+        orig_data = json.load(reader)["data"]
+
     qid_to_has_ans = make_qid_to_has_ans(orig_data)
     has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
     no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
......
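write_predictions_extended now takes the path to the original SQuAD file (orig_data_file) and loads the JSON itself before splitting question ids by answerability. A minimal re-implementation sketch of that split, assuming the standard SQuAD 2.0 JSON layout (the real make_qid_to_has_ans comes from the official SQuAD 2.0 evaluation script):

    import json

    def make_qid_to_has_ans_sketch(orig_data):
        # Map each question id to whether it has at least one gold answer.
        qid_to_has_ans = {}
        for article in orig_data:
            for paragraph in article['paragraphs']:
                for qa in paragraph['qas']:
                    qid_to_has_ans[qa['id']] = bool(qa['answers'])
        return qid_to_has_ans

    with open('dev-v2.0.json', 'r', encoding='utf-8') as reader:  # placeholder path
        orig_data = json.load(reader)['data']
    qid_to_has_ans = make_qid_to_has_ans_sketch(orig_data)
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]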
@@ -528,9 +528,9 @@ class PoolerEndLogits(nn.Module):
                 Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                 1.0 means token should be masked.
         """
+        slen, hsz = hidden_states.shape[-2:]
         assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
         if start_positions is not None:
-            slen, hsz = hidden_states.shape[-2:]
             start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
             start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
             start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
@@ -571,7 +571,7 @@ class PoolerAnswerClass(nn.Module):
                 no dependency on end_feature so that we can obtain one single `cls_logits`
                 for each sample
         """
-        slen, hsz = hidden_states.shape[-2:]
+        hsz = hidden_states.shape[-1]
         assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
         if start_positions is not None:
             start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
@@ -614,12 +614,21 @@ class SQuADHead(nn.Module):
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
-        **last_hidden_state**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
-            Sequence of hidden-states at the last layer of the model.
-        **mems**:
-            list of ``torch.FloatTensor`` (one for each layer):
-            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
+        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
+            Log probabilities for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
+            Indices for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size,)``
+            Log probabilities for the ``is_impossible`` label of the answers.
     """
     def __init__(self, config):
         super(SQuADHead, self).__init__()
@@ -667,8 +676,8 @@ class SQuADHead(nn.Module):
             start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)
 
             start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
-            start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
-            start_states = torch.gather(hidden_states, -2, start_top_index)  # shape (bsz, start_n_top, hsz)
+            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
             start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
             hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
......
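The start_top_index_exp rename above is more than cosmetic: torch.gather needs an index tensor with the same number of dimensions as its input, so the (bsz, start_n_top) top-k indices are expanded to (bsz, start_n_top, hsz) before gathering along the sequence dimension. Overwriting start_top_index with the expanded copy, as the old code did, corrupted the (bsz, start_n_top) tensor that is still needed as a model output. A small self-contained shape check (sizes are made up):

    import torch

    bsz, slen, hsz, n_top = 2, 7, 4, 3  # made-up sizes
    hidden_states = torch.randn(bsz, slen, hsz)
    start_log_probs = torch.randn(bsz, slen).softmax(dim=-1)

    start_top_log_probs, start_top_index = torch.topk(start_log_probs, n_top, dim=-1)
    # Expand a *copy* of the indices to hidden_states' rank for gather...
    start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)
    start_states = torch.gather(hidden_states, -2, start_top_index_exp)

    # ...so the original (bsz, n_top) indices survive to be returned.
    assert start_top_index.shape == (bsz, n_top)
    assert start_states.shape == (bsz, n_top, hsz)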
@@ -1167,12 +1167,23 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             1.0 means token should be masked. 0.0 mean token is not masked.
 
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
-            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
-        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
-            Span-start scores (before SoftMax).
-        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
-            Span-end scores (before SoftMax).
+        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
+        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
+            Log probabilities for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
+            Indices for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size,)``
+            Log probabilities for the ``is_impossible`` label of the answers.
         **mems**:
             list of ``torch.FloatTensor`` (one for each layer):
             that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
@@ -1243,12 +1254,10 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
                 loss_fct_cls = nn.BCEWithLogitsLoss()
                 cls_loss = loss_fct_cls(cls_logits, is_impossible)
 
-                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is
-                # comparable to start_loss and end_loss
+                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                 total_loss += cls_loss * 0.5
 
-                outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
-            else:
-                outputs = (total_loss, start_logits, end_logits) + outputs
+            outputs = (total_loss,) + outputs
 
         else:
             # during inference, compute the end logits based on beam search
@@ -1256,8 +1265,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)
 
             start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
-            start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
-            start_states = torch.gather(hidden_states, -2, start_top_index)  # shape (bsz, start_n_top, hsz)
+            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
             start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
             hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
@@ -1269,11 +1278,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
             end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)
 
-            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
-            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)  # get the representation of START as weighted sum of hidden states
+            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)  # Shape (batch size,): one single `cls_logits` for each sample
 
             outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
 
-        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
-        # or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
+        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
+        # or (if labels are provided) (total_loss,)
         return outputs
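At inference, the head returns the beam-search tensors with end candidates flattened to (batch_size, start_n_top * end_n_top): entry i * end_n_top + j pairs start candidate i with its j-th end candidate, which is how write_predictions_extended iterates them. A sketch of unflattening the candidates for one example b (names mirror the outputs above; start_n_top and end_n_top come from the config):

    # Enumerate candidate answer spans for example b from the flattened outputs.
    for i in range(start_n_top):
        start_log_prob = start_top_log_probs[b, i]
        start_index = start_top_index[b, i]
        for j in range(end_n_top):
            flat = i * end_n_top + j
            end_log_prob = end_top_log_probs[b, flat]
            end_index = end_top_index[b, flat]
            score = start_log_prob + end_log_prob  # rank spans [start_index, end_index] by this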
@@ -38,7 +38,10 @@ class TokenizationTest(unittest.TestCase):
             with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
                 vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
 
-            create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
+            input_text = u"UNwant\u00E9d,running"
+            output_text = u"unwanted, running"
+            create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)
 
             tokenizer = BertTokenizer(vocab_file)
......
@@ -41,7 +41,10 @@ class GPT2TokenizationTest(unittest.TestCase):
             with open(merges_file, "w") as fp:
                 fp.write("\n".join(merges))
 
-            create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
+            input_text = u"lower newer"
+            output_text = u"lower<unk>newer"
+            create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)
 
             tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
             text = "lower"
text = "lower"
......
@@ -42,7 +42,10 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
             with open(merges_file, "w") as fp:
                 fp.write("\n".join(merges))
 
-            create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
+            input_text = u"lower newer"
+            output_text = u"lower newer"
+            create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)
 
             tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
......
@@ -113,23 +113,24 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
     tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))
 
-def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
+def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
     tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)
 
-    text = u"He is very happy, UNwant\u00E9d,running"
-    tokens = tokenizer.tokenize(text)
+    tokens = tokenizer.tokenize(input_text)
     ids = tokenizer.convert_tokens_to_ids(tokens)
-    ids_2 = tokenizer.encode(text)
+    ids_2 = tokenizer.encode(input_text)
     tester.assertListEqual(ids, ids_2)
 
     tokens_2 = tokenizer.convert_ids_to_tokens(ids)
     text_2 = tokenizer.decode(ids)
+    tester.assertEqual(text_2, output_text)
 
     tester.assertNotEqual(len(tokens_2), 0)
    tester.assertIsInstance(text_2, (str, unicode))
 
-def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
-    create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
+def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
+    create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
     create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
     create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
     create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
@@ -34,7 +34,10 @@ class TransfoXLTokenizationTest(unittest.TestCase):
             with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
                 vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
 
-            create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
+            input_text = u"<unk> UNwanted , running"
+            output_text = u"<unk> unwanted, running"
+            create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)
 
             tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
......
@@ -41,7 +41,10 @@ class XLMTokenizationTest(unittest.TestCase):
             with open(merges_file, "w") as fp:
                 fp.write("\n".join(merges))
 
-            create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
+            input_text = u"lower newer"
+            output_text = u"lower newer"
+            create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)
 
             tokenizer = XLMTokenizer(vocab_file, merges_file)
......
@@ -32,7 +32,10 @@ class XLNetTokenizationTest(unittest.TestCase):
         with TemporaryDirectory() as tmpdirname:
             tokenizer.save_pretrained(tmpdirname)
 
-            create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
+            input_text = u"This is a test"
+            output_text = u"This is a test"
+            create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)
 
             tokens = tokenizer.tokenize(u'This is a test')
             self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
......
@@ -161,10 +161,9 @@ class BertTokenizer(PreTrainedTokenizer):
         """Converts an index (integer) in a token (string/unicode) using the vocab."""
         return self.ids_to_tokens.get(index, self.unk_token)
 
-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        tokens = self.convert_ids_to_tokens(tokens_ids)
-        out_string = ''.join(tokens).replace(' ##', '').strip()
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ' '.join(tokens).replace(' ##', '').strip()
         return out_string
 
     def save_vocabulary(self, vocab_path):
......
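Renaming _convert_ids_to_string to convert_tokens_to_string moves each tokenizer's string-assembly logic onto tokens, so decode() can share one pipeline (ids -> tokens -> string -> cleanup) across all tokenizers. For BERT wordpieces, the joining rule above behaves like this tiny worked example:

    tokens = ['un', '##want', '##ed', ',', 'running']
    out_string = ' '.join(tokens).replace(' ##', '').strip()
    # ' '.join(...)       -> 'un ##want ##ed , running'
    # .replace(' ##', '') -> 'unwanted , running'
    # decode() then applies clean_up_tokenization, yielding 'unwanted, running'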
@@ -185,9 +185,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         """Converts an index (integer) in a token (string/unicode) using the vocab."""
         return self.decoder.get(index)
 
-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        text = ''.join(tokens_ids)
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        text = ''.join(tokens)
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
         return text
......
@@ -174,9 +174,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
         """Converts an id in a token (BPE) using the vocab."""
         return self.decoder.get(index, self.unk_token)
 
-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ''.join(tokens).replace('</w>', ' ').strip()
         return out_string
 
     def save_vocabulary(self, save_directory):
......
@@ -229,9 +229,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         else:
             raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')
 
-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        out_string = ' '.join(tokens_ids).strip()
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ' '.join(tokens).strip()
        return out_string
 
     def convert_to_tensor(self, symbols):
......
@@ -361,28 +361,33 @@ class PreTrainedTokenizer(object):
             (resp.) a sequence of ids, using the vocabulary.
         """
         if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
-            return self.convert_token_to_id_with_added_voc(tokens)
+            return self._convert_token_to_id_with_added_voc(tokens)
 
         ids = []
         for token in tokens:
-            ids.append(self.convert_token_to_id_with_added_voc(token))
+            ids.append(self._convert_token_to_id_with_added_voc(token))
         if len(ids) > self.max_len:
             logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
                            "for this model ({} > {}). Running this sequence through the model will result in "
                            "indexing errors".format(len(ids), self.max_len))
         return ids
 
-    def convert_token_to_id_with_added_voc(self, token):
+    def _convert_token_to_id_with_added_voc(self, token):
         if token in self.added_tokens_encoder:
             return self.added_tokens_encoder[token]
         return self._convert_token_to_id(token)
 
     def _convert_token_to_id(self, token):
         raise NotImplementedError
 
+    def encode(self, text):
+        """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
+            same as self.convert_tokens_to_ids(self.tokenize(text)).
+        """
+        return self.convert_tokens_to_ids(self.tokenize(text))
+
     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
         """ Converts a single index or a sequence of indices (integers) in a token "
             (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
@@ -391,7 +396,10 @@ class PreTrainedTokenizer(object):
             skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
         """
         if isinstance(ids, int):
-            return self.convert_id_to_token(ids)
+            if ids in self.added_tokens_decoder:
+                return self.added_tokens_decoder[ids]
+            else:
+                return self._convert_id_to_token(ids)
         tokens = []
         for index in ids:
             if index in self.all_special_ids and skip_special_tokens:
@@ -402,34 +410,26 @@ class PreTrainedTokenizer(object):
                 tokens.append(self._convert_id_to_token(index))
         return tokens
 
     def _convert_id_to_token(self, index):
         raise NotImplementedError
 
-    def encode(self, text):
-        """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
-            same as self.convert_tokens_to_ids(self.tokenize(text)).
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string.
+            The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
+            but we often want to remove sub-word tokenization artifacts at the same time.
         """
-        return self.convert_tokens_to_ids(self.tokenize(text))
+        return ' '.join(self.convert_ids_to_tokens(tokens))
 
     def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
         """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
             with options to remove special tokens and clean up tokenization spaces.
         """
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
-        text = self._convert_ids_to_string(filtered_tokens)
+        text = self.convert_tokens_to_string(filtered_tokens)
         if clean_up_tokenization_spaces:
             text = clean_up_tokenization(text)
         return text
 
-    def _convert_ids_to_string(self, tokens_ids):
-        """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary.
-            roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)).
-        """
-        return ' '.join(self.convert_ids_to_tokens(tokens_ids))
-
     @property
     def special_tokens_map(self):
         """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
......
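With this refactor, encode() and decode() form the public round trip on PreTrainedTokenizer, built from overridable hooks (_convert_token_to_id, _convert_id_to_token, convert_tokens_to_string). A round-trip sketch using the standard bert-base-uncased shortcut (the intermediate wordpieces depend on the vocabulary, but the decoded string does not):

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    ids = tokenizer.encode(u"UNwanted, running")   # tokenize + convert_tokens_to_ids
    tokens = tokenizer.convert_ids_to_tokens(ids)  # e.g. ['unwanted', ',', 'running']
    text = tokenizer.decode(ids)                   # convert_tokens_to_string + clean-up
    assert text == u'unwanted, running'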
@@ -202,9 +202,9 @@ class XLMTokenizer(PreTrainedTokenizer):
         """Converts an index (integer) in a token (string/unicode) using the vocab."""
         return self.decoder.get(index, self.unk_token)
 
-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ''.join(tokens).replace('</w>', ' ').strip()
         return out_string
 
     def save_vocabulary(self, save_directory):
......
@@ -170,9 +170,9 @@ class XLNetTokenizer(PreTrainedTokenizer):
             token = token.decode('utf-8')
         return token
 
-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        out_string = ''.join(tokens_ids).replace(SPIECE_UNDERLINE, ' ')
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings for sub-words) in a single string."""
+        out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
         return out_string
 
     def save_vocabulary(self, save_directory):
@@ -184,6 +184,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
             return
         out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
 
-        copyfile(self.vocab_file, out_vocab_file)
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
 
         return (out_vocab_file,)