Commit 15d8b126 authored by thomwolf

update tokenizer - update squad example for xlnet

parent 3b469cb4
@@ -242,7 +242,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
     # Load data features from cache or dataset file
     cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
         'dev' if evaluate else 'train',
-        list(filter(None, args.model_name.split('/'))).pop(),
+        list(filter(None, args.model_name_or_path.split('/'))).pop(),
         str(args.max_seq_length),
         str(task)))
     if os.path.exists(cached_features_file):
@@ -282,8 +282,10 @@ def main():
     ## Required parameters
     parser.add_argument("--data_dir", default=None, type=str, required=True,
                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--model_name", default=None, type=str, required=True,
-                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--task_name", default=None, type=str, required=True,
                         help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
     parser.add_argument("--output_dir", default=None, type=str, required=True,
@@ -400,15 +402,11 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
-    args.model_type = ""
-    for key in MODEL_CLASSES:
-        if key in args.model_name.lower():
-            args.model_type = key  # take the first match in model types
-            break
+    args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
...
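Why this matters: the old code guessed args.model_type by substring-matching the checkpoint name, which breaks for local paths such as a fine-tuned model directory. The type is now an explicit flag used as a plain dictionary lookup. A minimal sketch of the resolution flow (the two entries are illustrative; run_glue defines the full MODEL_CLASSES mapping):

    # Sketch only, assuming the usual pytorch-transformers class names.
    MODEL_CLASSES = {
        'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
        'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
    }

    args.model_type = args.model_type.lower()   # explicit flag, e.g. --model_type bert
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    model = model_class.from_pretrained(args.model_name_or_path)  # shortcut name or local path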
@@ -213,7 +213,6 @@ def evaluate(args, model, tokenizer, prefix=""):
                 inputs.update({'cls_index': batch[4],
                                'p_mask': batch[5]})
             outputs = model(**inputs)
-            batch_start_logits, batch_end_logits = outputs[:2]
         for i, example_index in enumerate(example_indices):
             eval_feature = features[example_index.item()]
@@ -242,7 +241,8 @@ def evaluate(args, model, tokenizer, prefix=""):
         write_predictions_extended(examples, features, all_results, args.n_best_size,
                                    args.max_answer_length, output_prediction_file,
                                    output_nbest_file, output_null_log_odds_file, args.predict_file,
-                                   args.start_n_top, args.end_n_top, args.version_2_with_negative)
+                                   model.config.start_n_top, model.config.end_n_top,
+                                   args.version_2_with_negative, tokenizer, args.verbose_logging)
     else:
         write_predictions(examples, features, all_results, args.n_best_size,
                           args.max_answer_length, args.do_lower_case, output_prediction_file,
@@ -262,7 +262,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     input_file = args.predict_file if evaluate else args.train_file
     cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
         'dev' if evaluate else 'train',
-        list(filter(None, args.model_name.split('/'))).pop(),
+        list(filter(None, args.model_name_or_path.split('/'))).pop(),
         str(args.max_seq_length)))
     if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
         logger.info("Loading features from cached file %s", cached_features_file)
@@ -312,8 +312,10 @@ def main():
                         help="SQuAD json for training. E.g., train-v1.1.json")
     parser.add_argument("--predict_file", default=None, type=str, required=True,
                         help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
-    parser.add_argument("--model_name", default=None, type=str, required=True,
-                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model checkpoints and predictions will be written.")
@@ -438,15 +440,11 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
-    args.model_type = ""
-    for key in MODEL_CLASSES:
-        if key in args.model_name.lower():
-            args.model_type = key  # take the first match in model types
-            break
+    args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
...
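One effect of the change above: start_n_top and end_n_top are no longer CLI flags but are read from the model's configuration, so the decoding beam widths always match the head that produced the logits. A small illustration (the printed values are the XLNet defaults; a particular checkpoint may store different ones):

    # After the loading code in main() above:
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    print(model.config.start_n_top, model.config.end_n_top)   # e.g. 5 5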
@@ -60,8 +60,9 @@ class ExamplesTests(unittest.TestCase):
             "--warmup_steps=2",
             "--overwrite_output_dir",
             "--seed=42"]
-        model_name = "--model_name=bert-base-uncased"
-        with patch.object(sys, 'argv', testargs + [model_name]):
+        model_type, model_name = ("--model_type=bert",
+                                  "--model_name_or_path=bert-base-uncased")
+        with patch.object(sys, 'argv', testargs + [model_type, model_name]):
             result = run_glue.main()
             for value in result.values():
                 self.assertGreaterEqual(value, 0.75)
@@ -85,8 +86,9 @@ class ExamplesTests(unittest.TestCase):
             "--per_gpu_eval_batch_size=1",
             "--overwrite_output_dir",
             "--seed=42"]
-        model_name = "--model_name=bert-base-uncased"
-        with patch.object(sys, 'argv', testargs + [model_name]):
+        model_type, model_name = ("--model_type=bert",
+                                  "--model_name_or_path=bert-base-uncased")
+        with patch.object(sys, 'argv', testargs + [model_type, model_name]):
             result = run_squad.main()
             self.assertGreaterEqual(result['f1'], 30)
             self.assertGreaterEqual(result['exact'], 30)
...
@@ -87,6 +87,7 @@ class InputFeatures(object):
                  segment_ids,
                  cls_index,
                  p_mask,
+                 paragraph_len,
                  start_position=None,
                  end_position=None,
                  is_impossible=None):
@@ -101,6 +102,7 @@ class InputFeatures(object):
         self.segment_ids = segment_ids
         self.cls_index = cls_index
         self.p_mask = p_mask
+        self.paragraph_len = paragraph_len
         self.start_position = start_position
         self.end_position = end_position
         self.is_impossible = is_impossible
@@ -292,6 +294,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
             tokens.append(all_doc_tokens[split_token_index])
             segment_ids.append(sequence_b_segment_id)
             p_mask.append(0)
+        paragraph_len = doc_span.length

         # SEP token
         tokens.append(sep_token)
@@ -385,6 +388,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                     segment_ids=segment_ids,
                     cls_index=cls_index,
                     p_mask=p_mask,
+                    paragraph_len=paragraph_len,
                     start_position=start_position,
                     end_position=end_position,
                     is_impossible=span_is_impossible))
@@ -673,8 +677,9 @@ RawResultExtended = collections.namedtuple("RawResultExtended",
 def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
                                max_answer_length, output_prediction_file,
                                output_nbest_file,
-                               output_null_log_odds_file, orig_data,
-                               start_n_top, end_n_top, version_2_with_negative):
+                               output_null_log_odds_file, orig_data_file,
+                               start_n_top, end_n_top, version_2_with_negative,
+                               tokenizer, verbose_logging):
     """ XLNet write prediction logic (more complex than Bert's).
         Write final predictions to the json file and log-odds of null if needed.
@@ -764,13 +769,30 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
                 break
             feature = features[pred.feature_index]

-            tok_start_to_orig_index = feature.tok_start_to_orig_index
-            tok_end_to_orig_index = feature.tok_end_to_orig_index
-            start_orig_pos = tok_start_to_orig_index[pred.start_index]
-            end_orig_pos = tok_end_to_orig_index[pred.end_index]
-
-            paragraph_text = example.paragraph_text
-            final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
+            # XLNet un-tokenizer
+            # Let's keep it simple for now and see if we need all this later.
+            #
+            # tok_start_to_orig_index = feature.tok_start_to_orig_index
+            # tok_end_to_orig_index = feature.tok_end_to_orig_index
+            # start_orig_pos = tok_start_to_orig_index[pred.start_index]
+            # end_orig_pos = tok_end_to_orig_index[pred.end_index]
+            # paragraph_text = example.paragraph_text
+            # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
+
+            # Previously used Bert untokenizer
+            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
+            orig_doc_start = feature.token_to_orig_map[pred.start_index]
+            orig_doc_end = feature.token_to_orig_map[pred.end_index]
+            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
+            tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+
+            # Clean whitespace
+            tok_text = tok_text.strip()
+            tok_text = " ".join(tok_text.split())
+            orig_text = " ".join(orig_tokens)
+
+            final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
+                                        verbose_logging)

             if final_text in seen_predictions:
                 continue
@@ -829,6 +851,9 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
     with open(output_null_log_odds_file, "w") as writer:
         writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

+    with open(orig_data_file, "r", encoding='utf-8') as reader:
+        orig_data = json.load(reader)["data"]
+
     qid_to_has_ans = make_qid_to_has_ans(orig_data)
     has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
     no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
...
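The new code path recovers the answer string the way the BERT example does: detokenize the predicted sub-word span, map it back to the original document tokens, and let get_final_text align the two. A condensed sketch of that flow with made-up WordPiece tokens (in write_predictions_extended the real inputs come from the feature and example objects above):

    # Toy sub-word span; `tokenizer`, `verbose_logging` are the function's parameters.
    tok_tokens = ['un', '##want', '##ed']
    tok_text = tokenizer.convert_tokens_to_string(tok_tokens)   # 'unwanted'
    tok_text = " ".join(tok_text.strip().split())               # normalize whitespace
    orig_text = " ".join(['UNwanted'])                          # original document token(s)
    final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case, verbose_logging)
    # final_text keeps the casing of the original paragraph: 'UNwanted'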
@@ -528,9 +528,9 @@ class PoolerEndLogits(nn.Module):
                 Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
                 1.0 means token should be masked.
         """
-        slen, hsz = hidden_states.shape[-2:]
         assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
         if start_positions is not None:
+            slen, hsz = hidden_states.shape[-2:]
             start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
             start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
             start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)
@@ -571,7 +571,7 @@ class PoolerAnswerClass(nn.Module):
             no dependency on end_feature so that we can obtain one single `cls_logits`
             for each sample
         """
-        slen, hsz = hidden_states.shape[-2:]
+        hsz = hidden_states.shape[-1]
         assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
         if start_positions is not None:
             start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
@@ -614,12 +614,21 @@ class SQuADHead(nn.Module):
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
-        **last_hidden_state**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
-            Sequence of hidden-states at the last layer of the model.
-        **mems**:
-            list of ``torch.FloatTensor`` (one for each layer):
-            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
+        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
+            Log probabilities for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
+            Indices for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size,)``
+            Log probabilities for the ``is_impossible`` label of the answers.
     """
     def __init__(self, config):
         super(SQuADHead, self).__init__()
@@ -667,8 +676,8 @@ class SQuADHead(nn.Module):
             start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

             start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
-            start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
-            start_states = torch.gather(hidden_states, -2, start_top_index)  # shape (bsz, start_n_top, hsz)
+            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
             start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
             hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
...
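The rename to start_top_index_exp fixes a subtle bug: the expanded index used for the gather was overwriting the (bsz, start_n_top) index tensor that must be returned to the caller. A self-contained shape check of the gather pattern (toy sizes, not the real model dimensions):

    import torch

    bsz, slen, hsz, start_n_top = 2, 7, 4, 3
    hidden_states = torch.randn(bsz, slen, hsz)
    start_log_probs = torch.randn(bsz, slen).softmax(dim=-1)

    start_top_log_probs, start_top_index = torch.topk(start_log_probs, start_n_top, dim=-1)
    start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # (bsz, start_n_top, hsz)
    start_states = torch.gather(hidden_states, -2, start_top_index_exp)      # (bsz, start_n_top, hsz)

    assert start_top_index.shape == (bsz, start_n_top)   # still intact for the output tuple
    assert start_states.shape == (bsz, start_n_top, hsz)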
@@ -1167,12 +1167,23 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             1.0 means token should be masked. 0.0 mean token is not masked.

     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
-            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
-        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
-            Span-start scores (before SoftMax).
-        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
-            Span-end scores (before SoftMax).
+        **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
+        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
+            Log probabilities for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
+            Indices for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size,)``
+            Log probabilities for the ``is_impossible`` label of the answers.
         **mems**:
             list of ``torch.FloatTensor`` (one for each layer):
             that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
@@ -1243,12 +1254,10 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
                 loss_fct_cls = nn.BCEWithLogitsLoss()
                 cls_loss = loss_fct_cls(cls_logits, is_impossible)

-                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is
-                # comparable to start_loss and end_loss
+                # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                 total_loss += cls_loss * 0.5

-                outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
-            else:
-                outputs = (total_loss, start_logits, end_logits) + outputs
+            outputs = (total_loss,) + outputs

         else:
             # during inference, compute the end logits based on beam search
@@ -1256,8 +1265,8 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

             start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
-            start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
-            start_states = torch.gather(hidden_states, -2, start_top_index)  # shape (bsz, start_n_top, hsz)
+            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
             start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)
             hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
@@ -1269,11 +1278,11 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
             end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
             end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

-            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
-            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)
+            start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)  # get the representation of START as weighted sum of hidden states
+            cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)  # Shape (batch size,): one single `cls_logits` for each sample

             outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs

-        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems, (hidden states), (attentions)
-        # or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits), mems, (hidden states), (attentions)
+        # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
+        # or (if labels are provided) (total_loss,)
         return outputs
...
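After this commit XLNetForQuestionAnswering therefore returns only the loss tuple when both span positions are provided, and five beam-search tensors otherwise. A sketch of how a caller such as run_squad's evaluate loop unpacks the inference outputs (argument names follow the forward signature at this commit; treat the call as illustrative):

    # Inference: start_positions/end_positions are not passed.
    outputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask,
                    cls_index=cls_index, p_mask=p_mask)
    (start_top_log_probs, start_top_index,
     end_top_log_probs, end_top_index, cls_logits) = outputs[:5]
...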
@@ -38,7 +38,10 @@ class TokenizationTest(unittest.TestCase):
         with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-        create_and_check_tokenizer_commons(self, BertTokenizer, tmpdirname)
+        input_text = u"UNwant\u00E9d,running"
+        output_text = u"unwanted, running"
+        create_and_check_tokenizer_commons(self, input_text, output_text, BertTokenizer, tmpdirname)

         tokenizer = BertTokenizer(vocab_file)
...
@@ -41,7 +41,10 @@ class GPT2TokenizationTest(unittest.TestCase):
         with open(merges_file, "w") as fp:
             fp.write("\n".join(merges))

-        create_and_check_tokenizer_commons(self, GPT2Tokenizer, tmpdirname, **special_tokens_map)
+        input_text = u"lower newer"
+        output_text = u"lower<unk>newer"
+        create_and_check_tokenizer_commons(self, input_text, output_text, GPT2Tokenizer, tmpdirname, **special_tokens_map)

         tokenizer = GPT2Tokenizer(vocab_file, merges_file, **special_tokens_map)
         text = "lower"
...
@@ -42,7 +42,10 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         with open(merges_file, "w") as fp:
             fp.write("\n".join(merges))

-        create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, tmpdirname)
+        input_text = u"lower newer"
+        output_text = u"lower newer"
+        create_and_check_tokenizer_commons(self, input_text, output_text, OpenAIGPTTokenizer, tmpdirname)

         tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file)
...
@@ -113,23 +113,24 @@ def create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
     tester.assertEqual(tokens[-2], tokenizer.convert_tokens_to_ids(tokenizer.pad_token))

-def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
+def create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
     tokenizer = tokenizer_class.from_pretrained(*inputs, **kwargs)

-    text = u"He is very happy, UNwant\u00E9d,running"
-    tokens = tokenizer.tokenize(text)
+    tokens = tokenizer.tokenize(input_text)
     ids = tokenizer.convert_tokens_to_ids(tokens)
-    ids_2 = tokenizer.encode(text)
+    ids_2 = tokenizer.encode(input_text)
     tester.assertListEqual(ids, ids_2)

     tokens_2 = tokenizer.convert_ids_to_tokens(ids)
     text_2 = tokenizer.decode(ids)

+    tester.assertEqual(text_2, output_text)
+
     tester.assertNotEqual(len(tokens_2), 0)
     tester.assertIsInstance(text_2, (str, unicode))

-def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
-    create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
+def create_and_check_tokenizer_commons(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs):
+    create_and_check_required_methods_tokenizer(tester, input_text, output_text, tokenizer_class, *inputs, **kwargs)
     create_and_check_add_tokens_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
     create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
     create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
@@ -34,7 +34,10 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         with open(vocab_file, "w", encoding='utf-8') as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

-        create_and_check_tokenizer_commons(self, TransfoXLTokenizer, tmpdirname, lower_case=True)
+        input_text = u"<unk> UNwanted , running"
+        output_text = u"<unk> unwanted, running"
+        create_and_check_tokenizer_commons(self, input_text, output_text, TransfoXLTokenizer, tmpdirname, lower_case=True)

         tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
...
@@ -41,7 +41,10 @@ class XLMTokenizationTest(unittest.TestCase):
         with open(merges_file, "w") as fp:
             fp.write("\n".join(merges))

-        create_and_check_tokenizer_commons(self, XLMTokenizer, tmpdirname)
+        input_text = u"lower newer"
+        output_text = u"lower newer"
+        create_and_check_tokenizer_commons(self, input_text, output_text, XLMTokenizer, tmpdirname)

         tokenizer = XLMTokenizer(vocab_file, merges_file)
...
@@ -32,7 +32,10 @@ class XLNetTokenizationTest(unittest.TestCase):
         with TemporaryDirectory() as tmpdirname:
             tokenizer.save_pretrained(tmpdirname)

-            create_and_check_tokenizer_commons(self, XLNetTokenizer, tmpdirname)
+            input_text = u"This is a test"
+            output_text = u"This is a test"
+            create_and_check_tokenizer_commons(self, input_text, output_text, XLNetTokenizer, tmpdirname)

         tokens = tokenizer.tokenize(u'This is a test')
         self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
...
@@ -161,10 +161,9 @@ class BertTokenizer(PreTrainedTokenizer):
         """Converts an index (integer) in a token (string/unicode) using the vocab."""
         return self.ids_to_tokens.get(index, self.unk_token)

-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        tokens = self.convert_ids_to_tokens(tokens_ids)
-        out_string = ''.join(tokens).replace(' ##', '').strip()
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ' '.join(tokens).replace(' ##', '').strip()
         return out_string

     def save_vocabulary(self, vocab_path):
...
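The switch from ''.join to ' '.join is the substantive fix here: WordPiece continuation pieces carry a '##' prefix, so joining with spaces and then deleting ' ##' glues sub-words back together while keeping the spaces between words. A standalone illustration (pure string manipulation, no tokenizer needed):

    tokens = ['he', 'is', 'un', '##want', '##ed']
    print(' '.join(tokens).replace(' ##', '').strip())   # 'he is unwanted'
    # With the old ''.join, ' ##' never occurred in the joined string,
    # so the result stayed glued: 'heisun##want##ed'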
@@ -185,9 +185,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         """Converts an index (integer) in a token (string/unicode) using the vocab."""
         return self.decoder.get(index)

-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        text = ''.join(tokens_ids)
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        text = ''.join(tokens)
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
         return text
...
@@ -174,9 +174,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
         """Converts an id in a token (BPE) using the vocab."""
         return self.decoder.get(index, self.unk_token)

-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ''.join(tokens).replace('</w>', ' ').strip()
         return out_string

     def save_vocabulary(self, save_directory):
...
@@ -229,9 +229,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         else:
             raise ValueError('Token not in vocabulary and no <unk> token in vocabulary for replacement')

-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        out_string = ' '.join(tokens_ids).strip()
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ' '.join(tokens).strip()
         return out_string

     def convert_to_tensor(self, symbols):
...
@@ -361,28 +361,33 @@ class PreTrainedTokenizer(object):
             (resp.) a sequence of ids, using the vocabulary.
         """
         if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
-            return self.convert_token_to_id_with_added_voc(tokens)
+            return self._convert_token_to_id_with_added_voc(tokens)

         ids = []
         for token in tokens:
-            ids.append(self.convert_token_to_id_with_added_voc(token))
+            ids.append(self._convert_token_to_id_with_added_voc(token))
         if len(ids) > self.max_len:
             logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
                            "for this model ({} > {}). Running this sequence through the model will result in "
                            "indexing errors".format(len(ids), self.max_len))
         return ids

-    def convert_token_to_id_with_added_voc(self, token):
+    def _convert_token_to_id_with_added_voc(self, token):
         if token in self.added_tokens_encoder:
             return self.added_tokens_encoder[token]
         return self._convert_token_to_id(token)

     def _convert_token_to_id(self, token):
         raise NotImplementedError

+    def encode(self, text):
+        """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
+            same as self.convert_tokens_to_ids(self.tokenize(text)).
+        """
+        return self.convert_tokens_to_ids(self.tokenize(text))
+
     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
         """ Converts a single index or a sequence of indices (integers) in a token
             (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
@@ -391,7 +396,10 @@ class PreTrainedTokenizer(object):
             skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
         """
         if isinstance(ids, int):
-            return self.convert_id_to_token(ids)
+            if ids in self.added_tokens_decoder:
+                return self.added_tokens_decoder[ids]
+            else:
+                return self._convert_id_to_token(ids)
         tokens = []
         for index in ids:
             if index in self.all_special_ids and skip_special_tokens:
@@ -402,34 +410,26 @@ class PreTrainedTokenizer(object):
                 tokens.append(self._convert_id_to_token(index))
         return tokens

     def _convert_id_to_token(self, index):
         raise NotImplementedError

-    def encode(self, text):
-        """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
-            same as self.convert_tokens_to_ids(self.tokenize(text)).
-        """
-        return self.convert_tokens_to_ids(self.tokenize(text))
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string.
+            The most simple way to do it is ' '.join(self.convert_ids_to_tokens(token_ids))
+            but we often want to remove sub-word tokenization artifacts at the same time.
+        """
+        return ' '.join(self.convert_ids_to_tokens(tokens))

     def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
         """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
             with options to remove special tokens and clean up tokenization spaces.
         """
         filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
-        text = self._convert_ids_to_string(filtered_tokens)
+        text = self.convert_tokens_to_string(filtered_tokens)
         if clean_up_tokenization_spaces:
             text = clean_up_tokenization(text)
         return text

-    def _convert_ids_to_string(self, tokens_ids):
-        """ Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary.
-            roughtly same as ' '.join(self.convert_ids_to_tokens(token_ids)).
-        """
-        return ' '.join(self.convert_ids_to_tokens(tokens_ids))
-
     @property
     def special_tokens_map(self):
         """ A dictionary mapping special token class attribute (cls_token, unk_token...) to their
...
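With this refactor every tokenizer exposes a public, overridable token-to-string step, and decode() becomes ids, then tokens, then string, then cleanup. A quick round-trip sketch (BertTokenizer shown as an example; the exact decoded string depends on the vocabulary, so the comments are indicative):

    # Given any subclass instance, e.g. tokenizer = BertTokenizer.from_pretrained('bert-base-uncased'):
    ids = tokenizer.encode(u"UNwant\u00E9d,running")    # tokenize + convert_tokens_to_ids
    tokens = tokenizer.convert_ids_to_tokens(ids)       # sub-word tokens
    text = tokenizer.convert_tokens_to_string(tokens)   # subclass undoes its sub-word artifacts
    print(tokenizer.decode(ids))                        # same + space cleanup, roughly 'unwanted, running'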
@@ -202,9 +202,9 @@ class XLMTokenizer(PreTrainedTokenizer):
         """Converts an index (integer) in a token (string/unicode) using the vocab."""
         return self.decoder.get(index, self.unk_token)

-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        out_string = ''.join(tokens_ids).replace('</w>', ' ').strip()
+    def convert_tokens_to_string(self, tokens):
+        """ Converts a sequence of tokens (string) in a single string. """
+        out_string = ''.join(tokens).replace('</w>', ' ').strip()
         return out_string

     def save_vocabulary(self, save_directory):
...
@@ -170,9 +170,9 @@ class XLNetTokenizer(PreTrainedTokenizer):
             token = token.decode('utf-8')
         return token

-    def _convert_ids_to_string(self, tokens_ids):
-        """Converts a sequence of ids in a string."""
-        out_string = ''.join(tokens_ids).replace(SPIECE_UNDERLINE, ' ')
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings for sub-words) in a single string."""
+        out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
         return out_string

     def save_vocabulary(self, save_directory):
@@ -184,6 +184,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
             return
         out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])

-        copyfile(self.vocab_file, out_vocab_file)
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)

         return (out_vocab_file,)
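SentencePiece marks word starts with a leading '▁' (SPIECE_UNDERLINE) rather than BERT's '##' continuations, so the string is rebuilt by turning the markers into spaces; the newly added .strip() drops the resulting leading space. Using the tokens from the XLNet test above:

    SPIECE_UNDERLINE = u'▁'
    tokens = [u'▁This', u'▁is', u'▁a', u'▁t', u'est']
    print(''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip())   # 'This is a test'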