"...lm-evaluation-harness.git" did not exist on "6db83ddc2708192677a4e17c6336c0d92be4bf0c"
Commit fa84ae26 authored by Aymeric Augustin

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.

This is also Thomas' preference, because it allows for explicit variable
names, which make the code easier to understand.
parent 63e3827c
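For projects that want to pin this choice, black can also read the line length from a pyproject.toml at the repository root, so later invocations don't need the command-line flag. This is only a sketch of such a configuration; the repository may not actually include this file:

    [tool.black]
    line-length = 119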
@@ -24,8 +24,7 @@ import glob
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler

try:
@@ -35,19 +34,32 @@ except:
from tqdm import tqdm, trange

from transformers import (
    WEIGHTS_NAME,
    BertConfig,
    BertForQuestionAnswering,
    BertTokenizer,
    XLMConfig,
    XLMForQuestionAnswering,
    XLMTokenizer,
    XLNetConfig,
    XLNetForQuestionAnswering,
    XLNetTokenizer,
    DistilBertConfig,
    DistilBertForQuestionAnswering,
    DistilBertTokenizer,
)
from transformers import AdamW, get_linear_schedule_with_warmup

from utils_squad import (
    read_squad_examples,
    convert_examples_to_features,
    RawResult,
    write_predictions,
    RawResultExtended,
    write_predictions_extended,
)

# The follwing import is the official SQuAD evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library
@@ -56,16 +68,18 @@ from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad

logger = logging.getLogger(__name__)

ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
)

MODEL_CLASSES = {
    "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
    "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
    "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
}


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
@@ -73,9 +87,11 @@ def set_seed(args):
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def to_list(tensor):
    return tensor.detach().cpu().tolist()


def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
@@ -92,13 +108,18 @@ def train(args, train_dataset, model, tokenizer):
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    if args.fp16:
        try:
            from apex import amp
@@ -112,17 +133,21 @@ def train(args, train_dataset, model, tokenizer):
    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Num Epochs = %d", args.num_train_epochs)
    logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        " Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", t_total)
@@ -136,20 +161,21 @@ def train(args, train_dataset, model, tokenizer):
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
@@ -173,22 +199,26 @@ def train(args, train_dataset, model, tokenizer):
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
@@ -224,32 +254,31 @@ def evaluate(args, model, tokenizer, prefix=""):
        model.eval()
        batch = tuple(t.to(args.device) for t in batch)
        with torch.no_grad():
            inputs = {"input_ids": batch[0], "attention_mask": batch[1]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = None if args.model_type == "xlm" else batch[2]  # XLM don't use segment_ids
            example_indices = batch[3]
            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[4], "p_mask": batch[5]})
            outputs = model(**inputs)

        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            if args.model_type in ["xlnet", "xlm"]:
                # XLNet uses a more complex post-processing procedure
                result = RawResultExtended(
                    unique_id=unique_id,
                    start_top_log_probs=to_list(outputs[0][i]),
                    start_top_index=to_list(outputs[1][i]),
                    end_top_log_probs=to_list(outputs[2][i]),
                    end_top_index=to_list(outputs[3][i]),
                    cls_logits=to_list(outputs[4][i]),
                )
            else:
                result = RawResult(
                    unique_id=unique_id, start_logits=to_list(outputs[0][i]), end_logits=to_list(outputs[1][i])
                )
            all_results.append(result)

    # Compute predictions
@@ -260,23 +289,44 @@ def evaluate(args, model, tokenizer, prefix=""):
    else:
        output_null_log_odds_file = None

    if args.model_type in ["xlnet", "xlm"]:
        # XLNet uses a more complex post-processing procedure
        write_predictions_extended(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.predict_file,
            model.config.start_n_top,
            model.config.end_n_top,
            args.version_2_with_negative,
            tokenizer,
            args.verbose_logging,
        )
    else:
        write_predictions(
            examples,
            features,
            all_results,
            args.n_best_size,
            args.max_answer_length,
            args.do_lower_case,
            output_prediction_file,
            output_nbest_file,
            output_null_log_odds_file,
            args.verbose_logging,
            args.version_2_with_negative,
            args.null_score_diff_threshold,
        )

    # Evaluate with the official SQuAD script
    evaluate_options = EVAL_OPTS(
        data_file=args.predict_file, pred_file=output_prediction_file, na_prob_file=output_null_log_odds_file
    )
    results = evaluate_on_squad(evaluate_options)
    return results
@@ -287,24 +337,30 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    # Load data features from cache or dataset file
    input_file = args.predict_file if evaluate else args.train_file
    cached_features_file = os.path.join(
        os.path.dirname(input_file),
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )
    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", input_file)
        examples = read_squad_examples(
            input_file=input_file, is_training=not evaluate, version_2_with_negative=args.version_2_with_negative
        )
        features = convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
        )
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)
@@ -320,14 +376,21 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
    if evaluate:
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        dataset = TensorDataset(
            all_input_ids, all_input_mask, all_segment_ids, all_example_index, all_cls_index, all_p_mask
        )
    else:
        all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
        dataset = TensorDataset(
            all_input_ids,
            all_input_mask,
            all_segment_ids,
            all_start_positions,
            all_end_positions,
            all_cls_index,
            all_p_mask,
        )

    if output_examples:
        return dataset, examples, features
@@ -338,109 +401,190 @@ def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--train_file", default=None, type=str, required=True, help="SQuAD json for training. E.g., train-v1.1.json"
    )
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        required=True,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints and predictions will be written.",
    )

    ## Other parameters
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )

    parser.add_argument(
        "--version_2_with_negative",
        action="store_true",
        help="If true, the SQuAD examples contain some that do not have an answer.",
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold predict null.",
    )

    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded.",
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help="When splitting up a long document into chunks, how much stride to take between chunks.",
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help="The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
    )

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file.",
    )
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help="The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.",
    )
    parser.add_argument(
        "--verbose_logging",
        action="store_true",
        help="If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.",
    )
    parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()
@@ -452,16 +596,24 @@ def main():
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)
@@ -472,15 +624,21 @@ def main():
    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
@@ -495,7 +653,8 @@ def main():
    if args.fp16:
        try:
            import apex

            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
@@ -505,7 +664,6 @@ def main():
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Save the trained model and the tokenizer
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
@@ -515,39 +673,42 @@ def main():
        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        model.to(args.device)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)

            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)

            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
            results.update(result)

    logger.info("Results: {}".format(results))
# coding=utf-8
# Copyright 2018 XXX. All rights reserved.
#
@@ -37,14 +36,16 @@ class SquadExample(object):
    For examples without an answer, the start and end position are -1.
    """

    def __init__(
        self,
        qas_id,
        question_text,
        doc_tokens,
        orig_answer_text=None,
        start_position=None,
        end_position=None,
        is_impossible=None,
    ):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
@@ -59,8 +60,7 @@ class SquadExample(object):
    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (self.qas_id)
        s += ", question_text: %s" % (self.question_text)
        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
        if self.start_position:
            s += ", start_position: %d" % (self.start_position)
@@ -74,22 +74,24 @@ class SquadExample(object):
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(
        self,
        unique_id,
        example_index,
        doc_span_index,
        tokens,
        token_to_orig_map,
        token_is_max_context,
        input_ids,
        input_mask,
        segment_ids,
        cls_index,
        p_mask,
        paragraph_len,
        start_position=None,
        end_position=None,
        is_impossible=None,
    ):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
@@ -109,7 +111,7 @@ class InputFeatures(object):
def read_squad_examples(input_file, is_training, version_2_with_negative):
    """Read a SQuAD json file into a list of SquadExample."""
    with open(input_file, "r", encoding="utf-8") as reader:
        input_data = json.load(reader)["data"]

    def is_whitespace(c):
@@ -146,8 +148,7 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
                    if version_2_with_negative:
                        is_impossible = qa["is_impossible"]
                    if (len(qa["answers"]) != 1) and (not is_impossible):
                        raise ValueError("For training, each question should have exactly 1 answer.")
                    if not is_impossible:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
@@ -161,12 +162,10 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
                        #
                        # Note that this means for training mode, every example is NOT
                        # guaranteed to be preserved.
                        actual_text = " ".join(doc_tokens[start_position : (end_position + 1)])
                        cleaned_answer_text = " ".join(whitespace_tokenize(orig_answer_text))
                        if actual_text.find(cleaned_answer_text) == -1:
                            logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
                            continue
                    else:
                        start_position = -1
@@ -180,18 +179,29 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
                    orig_answer_text=orig_answer_text,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=is_impossible,
                )
                examples.append(example)
    return examples


def convert_examples_to_features(
    examples,
    tokenizer,
    max_seq_length,
    doc_stride,
    max_query_length,
    is_training,
    cls_token_at_end=False,
    cls_token="[CLS]",
    sep_token="[SEP]",
    pad_token=0,
    sequence_a_segment_id=0,
    sequence_b_segment_id=1,
    cls_token_segment_id=0,
    pad_token_segment_id=0,
    mask_padding_with_zero=True,
):
    """Loads a data file into a list of `InputBatch`s."""

    unique_id = 1000000000
@@ -232,8 +242,8 @@ def convert_examples_to_features(
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.orig_answer_text
            )

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
@@ -241,8 +251,7 @@ def convert_examples_to_features(
        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of the up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])  # pylint: disable=invalid-name
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
@@ -287,8 +296,7 @@ def convert_examples_to_features(
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans, doc_span_index, split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(sequence_b_segment_id)
@@ -333,8 +341,7 @@ def convert_examples_to_features(
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
@@ -355,24 +362,23 @@ def convert_examples_to_features(
                logger.info("example_index: %s" % (example_index))
                logger.info("doc_span_index: %s" % (doc_span_index))
                logger.info("tokens: %s" % " ".join(tokens))
                logger.info(
                    "token_to_orig_map: %s" % " ".join(["%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()])
                )
                logger.info(
                    "token_is_max_context: %s"
                    % " ".join(["%d:%s" % (x, y) for (x, y) in token_is_max_context.items()])
                )
                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
                logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
                if is_training and span_is_impossible:
                    logger.info("impossible example")
                if is_training and not span_is_impossible:
                    answer_text = " ".join(tokens[start_position : (end_position + 1)])
                    logger.info("start_position: %d" % (start_position))
                    logger.info("end_position: %d" % (end_position))
                    logger.info("answer: %s" % (answer_text))

            features.append(
                InputFeatures(
@@ -390,14 +396,15 @@ def convert_examples_to_features(
                    paragraph_len=paragraph_len,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=span_is_impossible,
                )
            )
            unique_id += 1

    return features

def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""

    # The SQuAD annotations are character based. We first project them to
@@ -426,7 +433,7 @@ def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)
@@ -470,13 +477,23 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
    return cur_span_index == best_span_index


RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])


def write_predictions(
    all_examples,
    all_features,
    all_results,
    n_best_size,
    max_answer_length,
    do_lower_case,
    output_prediction_file,
    output_nbest_file,
    output_null_log_odds_file,
    verbose_logging,
    version_2_with_negative,
    null_score_diff_threshold,
):
    """Write final predictions to the json file and log-odds of null if needed."""
    logger.info("Writing predictions to: %s" % (output_prediction_file))
    logger.info("Writing nbest to: %s" % (output_nbest_file))
@@ -490,8 +507,8 @@ def write_predictions(
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"]
    )

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
@@ -544,7 +561,9 @@ def write_predictions(
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index],
                        )
                    )
        if version_2_with_negative:
            prelim_predictions.append(
                _PrelimPrediction(
@@ -552,14 +571,14 @@ def write_predictions(
                    start_index=0,
                    end_index=0,
                    start_logit=null_start_logit,
                    end_logit=null_end_logit,
                )
            )
        prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"]
        )

        seen_predictions = {}
        nbest = []
...@@ -568,10 +587,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -568,10 +587,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
break break
feature = features[pred.feature_index] feature = features[pred.feature_index]
if pred.start_index > 0: # this is a non-null prediction if pred.start_index > 0: # this is a non-null prediction
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index] orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index] orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
tok_text = " ".join(tok_tokens) tok_text = " ".join(tok_tokens)
# De-tokenize WordPieces that have been split off. # De-tokenize WordPieces that have been split off.
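# Hedged sketch of the (elided) de-tokenization step, assuming the standard BERT recipe:
# sub-word continuations are marked with "##", so joining the pieces and dropping the marker
# recovers plain whitespace-separated text that get_final_text() can align. The token list is
# a hypothetical example, not output of the library tokenizer.
tok_pieces = ["john", "##ans", "##on", "'", "s", "house"]
tok_text = " ".join(tok_pieces)
tok_text = tok_text.replace(" ##", "")          # re-attach sub-word continuations
tok_text = " ".join(tok_text.strip().split())   # normalize whitespace
assert tok_text == "johanson ' s house"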
...@@ -592,31 +611,21 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -592,31 +611,21 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
final_text = "" final_text = ""
seen_predictions[final_text] = True seen_predictions[final_text] = True
nbest.append( nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit))
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit))
# if we didn't include the empty option in the n-best, include it # if we didn't include the empty option in the n-best, include it
if version_2_with_negative: if version_2_with_negative:
if "" not in seen_predictions: if "" not in seen_predictions:
nbest.append( nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit))
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit))
# In very rare edge cases we could have only a single null prediction. # In very rare edge cases we could have only a single null prediction.
# So we just create a nonce prediction in this case to avoid failure. # So we just create a nonce prediction in this case to avoid failure.
if len(nbest)==1: if len(nbest) == 1:
nbest.insert(0, nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
# In very rare edge cases we could have no valid predictions. So we # In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure. # just create a nonce prediction in this case to avoid failure.
if not nbest: if not nbest:
nbest.append( nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1 assert len(nbest) >= 1
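# Hedged sketch of how the (elided) ranking step typically turns the joint logits of the
# n-best entries into probabilities: a numerically stable softmax over start_logit + end_logit.
# The helper name and the sample scores below are illustrative, not the library code.
import math

def softmax(scores):
    if not scores:
        return []
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

total_scores = [7.1, 5.4, -2.0]          # e.g. start_logit + end_logit per n-best entry
probs = softmax(total_scores)
assert abs(sum(probs) - 1.0) < 1e-9 and probs[0] > probs[1] > probs[2]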
...@@ -645,8 +654,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -645,8 +654,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
all_predictions[example.qas_id] = nbest_json[0]["text"] all_predictions[example.qas_id] = nbest_json[0]["text"]
else: else:
# predict "" iff the null score - the score of best non-null > threshold # predict "" iff the null score - the score of best non-null > threshold
score_diff = score_null - best_non_null_entry.start_logit - ( score_diff = score_null - best_non_null_entry.start_logit - (best_non_null_entry.end_logit)
best_non_null_entry.end_logit)
scores_diff_json[example.qas_id] = score_diff scores_diff_json[example.qas_id] = score_diff
if score_diff > null_score_diff_threshold: if score_diff > null_score_diff_threshold:
all_predictions[example.qas_id] = "" all_predictions[example.qas_id] = ""
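# Tiny worked example (hypothetical numbers) of the SQuAD 2.0 decision rule above: the empty
# answer wins only when the null score beats the best span by more than the threshold.
score_null = -1.0
best_start_logit, best_end_logit = 2.5, 1.5
null_score_diff_threshold = 0.0
score_diff = score_null - best_start_logit - best_end_logit     # -5.0
prediction = "" if score_diff > null_score_diff_threshold else "best non-null span"
assert prediction == "best non-null span"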
...@@ -668,29 +676,40 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -668,29 +676,40 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
# For XLNet (and XLM which uses the same head) # For XLNet (and XLM which uses the same head)
RawResultExtended = collections.namedtuple("RawResultExtended", RawResultExtended = collections.namedtuple(
["unique_id", "start_top_log_probs", "start_top_index", "RawResultExtended",
"end_top_log_probs", "end_top_index", "cls_logits"]) ["unique_id", "start_top_log_probs", "start_top_index", "end_top_log_probs", "end_top_index", "cls_logits"],
)
def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
max_answer_length, output_prediction_file, def write_predictions_extended(
output_nbest_file, all_examples,
output_null_log_odds_file, orig_data_file, all_features,
start_n_top, end_n_top, version_2_with_negative, all_results,
tokenizer, verbose_logging): n_best_size,
max_answer_length,
output_prediction_file,
output_nbest_file,
output_null_log_odds_file,
orig_data_file,
start_n_top,
end_n_top,
version_2_with_negative,
tokenizer,
verbose_logging,
):
""" XLNet write prediction logic (more complex than Bert's). """ XLNet write prediction logic (more complex than Bert's).
Write final predictions to the json file and log-odds of null if needed. Write final predictions to the json file and log-odds of null if needed.
Requires utils_squad_evaluate.py Requires utils_squad_evaluate.py
""" """
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction", "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
["feature_index", "start_index", "end_index", )
"start_log_prob", "end_log_prob"])
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"]) "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]
)
logger.info("Writing predictions to: %s", output_prediction_file) logger.info("Writing predictions to: %s", output_prediction_file)
# logger.info("Writing nbest to: %s" % (output_nbest_file)) # logger.info("Writing nbest to: %s" % (output_nbest_file))
...@@ -754,12 +773,13 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s ...@@ -754,12 +773,13 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
start_index=start_index, start_index=start_index,
end_index=end_index, end_index=end_index,
start_log_prob=start_log_prob, start_log_prob=start_log_prob,
end_log_prob=end_log_prob)) end_log_prob=end_log_prob,
)
)
prelim_predictions = sorted( prelim_predictions = sorted(
prelim_predictions, prelim_predictions, key=lambda x: (x.start_log_prob + x.end_log_prob), reverse=True
key=lambda x: (x.start_log_prob + x.end_log_prob), )
reverse=True)
seen_predictions = {} seen_predictions = {}
nbest = [] nbest = []
...@@ -770,7 +790,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s ...@@ -770,7 +790,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
# XLNet un-tokenizer # XLNet un-tokenizer
# Let's keep it simple for now and see if we need all this later. # Let's keep it simple for now and see if we need all this later.
# #
# tok_start_to_orig_index = feature.tok_start_to_orig_index # tok_start_to_orig_index = feature.tok_start_to_orig_index
# tok_end_to_orig_index = feature.tok_end_to_orig_index # tok_end_to_orig_index = feature.tok_end_to_orig_index
# start_orig_pos = tok_start_to_orig_index[pred.start_index] # start_orig_pos = tok_start_to_orig_index[pred.start_index]
...@@ -779,10 +799,10 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s ...@@ -779,10 +799,10 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
# final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip() # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
# Previously used Bert untokenizer # Previously used Bert untokenizer
tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)] tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)]
orig_doc_start = feature.token_to_orig_map[pred.start_index] orig_doc_start = feature.token_to_orig_map[pred.start_index]
orig_doc_end = feature.token_to_orig_map[pred.end_index] orig_doc_end = feature.token_to_orig_map[pred.end_index]
orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)] orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)]
tok_text = tokenizer.convert_tokens_to_string(tok_tokens) tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
# Clean whitespace # Clean whitespace
...@@ -790,8 +810,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s ...@@ -790,8 +810,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
tok_text = " ".join(tok_text.split()) tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens) orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case, final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case, verbose_logging)
verbose_logging)
if final_text in seen_predictions: if final_text in seen_predictions:
continue continue
...@@ -799,17 +818,13 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s ...@@ -799,17 +818,13 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
seen_predictions[final_text] = True seen_predictions[final_text] = True
nbest.append( nbest.append(
_NbestPrediction( _NbestPrediction(text=final_text, start_log_prob=pred.start_log_prob, end_log_prob=pred.end_log_prob)
text=final_text, )
start_log_prob=pred.start_log_prob,
end_log_prob=pred.end_log_prob))
# In very rare edge cases we could have no valid predictions. So we # In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure. # just create a nonce prediction in this case to avoid failure.
if not nbest: if not nbest:
nbest.append( nbest.append(_NbestPrediction(text="", start_log_prob=-1e6, end_log_prob=-1e6))
_NbestPrediction(text="", start_log_prob=-1e6,
end_log_prob=-1e6))
total_scores = [] total_scores = []
best_non_null_entry = None best_non_null_entry = None
...@@ -850,7 +865,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s ...@@ -850,7 +865,7 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_s
with open(output_null_log_odds_file, "w") as writer: with open(output_null_log_odds_file, "w") as writer:
writer.write(json.dumps(scores_diff_json, indent=4) + "\n") writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
with open(orig_data_file, "r", encoding='utf-8') as reader: with open(orig_data_file, "r", encoding="utf-8") as reader:
orig_data = json.load(reader)["data"] orig_data = json.load(reader)["data"]
qid_to_has_ans = make_qid_to_has_ans(orig_data) qid_to_has_ans = make_qid_to_has_ans(orig_data)
...@@ -914,8 +929,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -914,8 +929,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
start_position = tok_text.find(pred_text) start_position = tok_text.find(pred_text)
if start_position == -1: if start_position == -1:
if verbose_logging: if verbose_logging:
logger.info( logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
return orig_text return orig_text
end_position = start_position + len(pred_text) - 1 end_position = start_position + len(pred_text) - 1
...@@ -924,8 +938,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -924,8 +938,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
if len(orig_ns_text) != len(tok_ns_text): if len(orig_ns_text) != len(tok_ns_text):
if verbose_logging: if verbose_logging:
logger.info("Length not equal after stripping spaces: '%s' vs '%s'", logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text)
orig_ns_text, tok_ns_text)
return orig_text return orig_text
# We then project the characters in `pred_text` back to `orig_text` using # We then project the characters in `pred_text` back to `orig_text` using
...@@ -956,7 +969,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -956,7 +969,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
logger.info("Couldn't map end position") logger.info("Couldn't map end position")
return orig_text return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)] output_text = orig_text[orig_start_position : (orig_end_position + 1)]
return output_text return output_text
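# Hedged sketch of the alignment trick used in get_final_text() above: strip whitespace from
# both strings, keep a map from stripped positions back to original positions, then project the
# predicted span. The helper is illustrative and simplified (it ignores casing and tokenizer quirks).
def strip_spaces(text):
    ns_chars, ns_to_orig = [], {}
    for i, c in enumerate(text):
        if c != " ":
            ns_to_orig[len(ns_chars)] = i
            ns_chars.append(c)
    return "".join(ns_chars), ns_to_orig

orig_text = "John  Smith Jr."        # hypothetical original span (double space on purpose)
pred_text = "John Smith"             # hypothetical model prediction after de-tokenization
orig_ns, orig_map = strip_spaces(orig_text)
pred_ns, _ = strip_spaces(pred_text)
start_ns = orig_ns.find(pred_ns)                       # position in the stripped string
end_ns = start_ns + len(pred_ns) - 1
projected = orig_text[orig_map[start_ns] : orig_map[end_ns] + 1]
assert projected == "John  Smith"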
......
...@@ -27,8 +27,8 @@ from .configuration_utils import PretrainedConfig ...@@ -27,8 +27,8 @@ from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = { XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-config.json", "xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-config.json",
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-config.json", "xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-config.json",
} }
...@@ -63,24 +63,26 @@ class XxxConfig(PretrainedConfig): ...@@ -63,24 +63,26 @@ class XxxConfig(PretrainedConfig):
""" """
pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
def __init__(self, def __init__(
vocab_size=50257, self,
n_positions=1024, vocab_size=50257,
n_ctx=1024, n_positions=1024,
n_embd=768, n_ctx=1024,
n_layer=12, n_embd=768,
n_head=12, n_layer=12,
resid_pdrop=0.1, n_head=12,
embd_pdrop=0.1, resid_pdrop=0.1,
attn_pdrop=0.1, embd_pdrop=0.1,
layer_norm_epsilon=1e-5, attn_pdrop=0.1,
initializer_range=0.02, layer_norm_epsilon=1e-5,
summary_type='cls_index', initializer_range=0.02,
summary_use_proj=True, summary_type="cls_index",
summary_activation=None, summary_use_proj=True,
summary_proj_to_labels=True, summary_activation=None,
summary_first_dropout=0.1, summary_proj_to_labels=True,
**kwargs): summary_first_dropout=0.1,
**kwargs
):
super(XxxConfig, self).__init__(**kwargs) super(XxxConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.n_ctx = n_ctx self.n_ctx = n_ctx
......
...@@ -24,8 +24,10 @@ import torch ...@@ -24,8 +24,10 @@ import torch
from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx from transformers import XxxConfig, XxxForPreTraining, load_tf_weights_in_xxx
import logging import logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path): def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
# Initialise PyTorch model # Initialise PyTorch model
config = XxxConfig.from_json_file(config_file) config = XxxConfig.from_json_file(config_file)
...@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du ...@@ -43,23 +45,19 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## Required parameters ## Required parameters
parser.add_argument("--tf_checkpoint_path", parser.add_argument(
default = None, "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
type = str, )
required = True, parser.add_argument(
help = "Path to the TensorFlow checkpoint path.") "--config_file",
parser.add_argument("--config_file", default=None,
default = None, type=str,
type = str, required=True,
required = True, help="The config json file corresponding to the pre-trained model. \n"
help = "The config json file corresponding to the pre-trained model. \n" "This specifies the model architecture.",
"This specifies the model architecture.") )
parser.add_argument("--pytorch_dump_path", parser.add_argument(
default = None, "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
type = str, )
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args() args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
args.config_file,
args.pytorch_dump_path)
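# Hedged usage sketch for the conversion entry point above; the script name and all file paths
# are hypothetical placeholders, while the flag names come from the parser defined above.
#
#   python convert_xxx_checkpoint.py \
#       --tf_checkpoint_path ./tf_ckpt/model.ckpt \
#       --config_file ./tf_ckpt/config.json \
#       --pytorch_dump_path ./pytorch_model.bin
#
# or equivalently, calling the function directly from Python:
#
#   convert_tf_checkpoint_to_pytorch("./tf_ckpt/model.ckpt", "./tf_ckpt/config.json", "./pytorch_model.bin")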
...@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__) ...@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
# for the pretrained weights provided with the models # for the pretrained weights provided with the models
#################################################### ####################################################
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = { TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5", "xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-tf_model.h5",
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5", "xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-tf_model.h5",
} }
#################################################### ####################################################
...@@ -69,9 +69,9 @@ TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = { ...@@ -69,9 +69,9 @@ TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
class TFXxxLayer(tf.keras.layers.Layer): class TFXxxLayer(tf.keras.layers.Layer):
def __init__(self, config, **kwargs): def __init__(self, config, **kwargs):
super(TFXxxLayer, self).__init__(**kwargs) super(TFXxxLayer, self).__init__(**kwargs)
self.attention = TFXxxAttention(config, name='attention') self.attention = TFXxxAttention(config, name="attention")
self.intermediate = TFXxxIntermediate(config, name='intermediate') self.intermediate = TFXxxIntermediate(config, name="intermediate")
self.transformer_output = TFXxxOutput(config, name='output') self.transformer_output = TFXxxOutput(config, name="output")
def call(self, inputs, training=False): def call(self, inputs, training=False):
hidden_states, attention_mask, head_mask = inputs hidden_states, attention_mask, head_mask = inputs
...@@ -98,7 +98,9 @@ class TFXxxMainLayer(tf.keras.layers.Layer): ...@@ -98,7 +98,9 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
def _prune_heads(self, heads_to_prune): def _prune_heads(self, heads_to_prune):
raise NotImplementedError # Not implemented yet in the library for TF 2.0 models raise NotImplementedError  # Not implemented yet in the library for TF 2.0 models
def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False): def call(
self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False
):
# We allow three types of multi-inputs: # We allow three types of multi-inputs:
# - traditional keyword arguments in the call method # - traditional keyword arguments in the call method
# - all the arguments provided as a dict in the first positional argument of call # - all the arguments provided as a dict in the first positional argument of call
...@@ -113,11 +115,11 @@ class TFXxxMainLayer(tf.keras.layers.Layer): ...@@ -113,11 +115,11 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
head_mask = inputs[4] if len(inputs) > 4 else head_mask head_mask = inputs[4] if len(inputs) > 4 else head_mask
assert len(inputs) <= 5, "Too many inputs." assert len(inputs) <= 5, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get('input_ids') input_ids = inputs.get("input_ids")
attention_mask = inputs.get('attention_mask', attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get('token_type_ids', token_type_ids) token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get('position_ids', position_ids) position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get('head_mask', head_mask) head_mask = inputs.get("head_mask", head_mask)
assert len(inputs) <= 5, "Too many inputs." assert len(inputs) <= 5, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
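# Hedged, self-contained sketch of the input-dispatch pattern used in call() above; the helper
# is illustrative (not the library code) and mirrors the three accepted call styles: a
# positional list/tuple, a dict keyed by input name, or a single input_ids tensor.
def unpack_inputs(inputs, attention_mask=None):
    if isinstance(inputs, (tuple, list)):
        input_ids = inputs[0]
        attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
    elif isinstance(inputs, dict):
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask", attention_mask)
    else:
        input_ids = inputs
    return input_ids, attention_mask

assert unpack_inputs([1, 2]) == (1, 2)
assert unpack_inputs({"input_ids": 1}, attention_mask=3) == (1, 3)
assert unpack_inputs(1) == (1, None)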
...@@ -175,6 +177,7 @@ class TFXxxPreTrainedModel(TFPreTrainedModel): ...@@ -175,6 +177,7 @@ class TFXxxPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and """ An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = XxxConfig config_class = XxxConfig
pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer" base_model_prefix = "transformer"
...@@ -263,8 +266,12 @@ XXX_INPUTS_DOCSTRING = r""" ...@@ -263,8 +266,12 @@ XXX_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix. than the model's internal embedding lookup matrix.
""" """
@add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) @add_start_docstrings(
"The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
XXX_START_DOCSTRING,
XXX_INPUTS_DOCSTRING,
)
class TFXxxModel(TFXxxPreTrainedModel): class TFXxxModel(TFXxxPreTrainedModel):
r""" r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
...@@ -297,17 +304,19 @@ class TFXxxModel(TFXxxPreTrainedModel): ...@@ -297,17 +304,19 @@ class TFXxxModel(TFXxxPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(TFXxxModel, self).__init__(config, *inputs, **kwargs) super(TFXxxModel, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXxxMainLayer(config, name='transformer') self.transformer = TFXxxMainLayer(config, name="transformer")
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs) outputs = self.transformer(inputs, **kwargs)
return outputs return outputs
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """, @add_start_docstrings(
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) """Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
)
class TFXxxForMaskedLM(TFXxxPreTrainedModel): class TFXxxForMaskedLM(TFXxxPreTrainedModel):
r""" r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
...@@ -333,26 +342,30 @@ class TFXxxForMaskedLM(TFXxxPreTrainedModel): ...@@ -333,26 +342,30 @@ class TFXxxForMaskedLM(TFXxxPreTrainedModel):
prediction_scores = outputs[0] prediction_scores = outputs[0]
""" """
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs) super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs)
self.transformer = TFXxxMainLayer(config, name='transformer') self.transformer = TFXxxMainLayer(config, name="transformer")
self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name='mlm') self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name="mlm")
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs) outputs = self.transformer(inputs, **kwargs)
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False)) prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
return outputs # prediction_scores, (hidden_states), (attentions) return outputs # prediction_scores, (hidden_states), (attentions)
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of @add_start_docstrings(
"""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """, the pooled output) e.g. for GLUE tasks. """,
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) XXX_START_DOCSTRING,
XXX_INPUTS_DOCSTRING,
)
class TFXxxForSequenceClassification(TFXxxPreTrainedModel): class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
r""" r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
...@@ -378,22 +391,23 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel): ...@@ -378,22 +391,23 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
logits = outputs[0] logits = outputs[0]
""" """
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs) super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.transformer = TFXxxMainLayer(config, name='transformer') self.transformer = TFXxxMainLayer(config, name="transformer")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, self.classifier = tf.keras.layers.Dense(
kernel_initializer=get_initializer(config.initializer_range), config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
name='classifier') )
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs) outputs = self.transformer(inputs, **kwargs)
pooled_output = outputs[1] pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False)) pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
...@@ -401,9 +415,12 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel): ...@@ -401,9 +415,12 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
return outputs # logits, (hidden_states), (attentions) return outputs # logits, (hidden_states), (attentions)
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of @add_start_docstrings(
"""Xxx Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) XXX_START_DOCSTRING,
XXX_INPUTS_DOCSTRING,
)
class TFXxxForTokenClassification(TFXxxPreTrainedModel): class TFXxxForTokenClassification(TFXxxPreTrainedModel):
r""" r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
...@@ -429,22 +446,23 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel): ...@@ -429,22 +446,23 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
scores = outputs[0] scores = outputs[0]
""" """
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(TFXxxForTokenClassification, self).__init__(config, *inputs, **kwargs) super(TFXxxForTokenClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.transformer = TFXxxMainLayer(config, name='transformer') self.transformer = TFXxxMainLayer(config, name="transformer")
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
self.classifier = tf.keras.layers.Dense(config.num_labels, self.classifier = tf.keras.layers.Dense(
kernel_initializer=get_initializer(config.initializer_range), config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
name='classifier') )
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs) outputs = self.transformer(inputs, **kwargs)
sequence_output = outputs[0] sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False)) sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
logits = self.classifier(sequence_output) logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
...@@ -452,9 +470,12 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel): ...@@ -452,9 +470,12 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
return outputs # scores, (hidden_states), (attentions) return outputs # scores, (hidden_states), (attentions)
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of @add_start_docstrings(
"""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
the hidden-states output to compute `span start logits` and `span end logits`). """, the hidden-states output to compute `span start logits` and `span end logits`). """,
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) XXX_START_DOCSTRING,
XXX_INPUTS_DOCSTRING,
)
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel): class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
r""" r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
...@@ -482,14 +503,15 @@ class TFXxxForQuestionAnswering(TFXxxPreTrainedModel): ...@@ -482,14 +503,15 @@ class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
start_scores, end_scores = outputs[:2] start_scores, end_scores = outputs[:2]
""" """
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs) super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.transformer = TFXxxMainLayer(config, name='transformer') self.transformer = TFXxxMainLayer(config, name="transformer")
self.qa_outputs = tf.keras.layers.Dense(config.num_labels, self.qa_outputs = tf.keras.layers.Dense(
kernel_initializer=get_initializer(config.initializer_range), config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
name='qa_outputs') )
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
outputs = self.transformer(inputs, **kwargs) outputs = self.transformer(inputs, **kwargs)
......
...@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__) ...@@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
# for the pretrained weights provided with the models # for the pretrained weights provided with the models
#################################################### ####################################################
XXX_PRETRAINED_MODEL_ARCHIVE_MAP = { XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-pytorch_model.bin", "xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-pytorch_model.bin",
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-pytorch_model.bin", "xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-pytorch_model.bin",
} }
#################################################### ####################################################
...@@ -60,8 +60,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path): ...@@ -60,8 +60,10 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
except ImportError: except ImportError:
logger.error("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see " logger.error(
"https://www.tensorflow.org/install/ for installation instructions.") "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise raise
tf_path = os.path.abspath(tf_checkpoint_path) tf_path = os.path.abspath(tf_checkpoint_path)
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path)) logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
...@@ -76,7 +78,7 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path): ...@@ -76,7 +78,7 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
arrays.append(array) arrays.append(array)
for name, array in zip(names, arrays): for name, array in zip(names, arrays):
name = name.split('/') name = name.split("/")
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
# which are not required for using pretrained model # which are not required for using pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name): if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
...@@ -84,18 +86,18 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path): ...@@ -84,18 +86,18 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
continue continue
pointer = model pointer = model
for m_name in name: for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name): if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
l = re.split(r'_(\d+)', m_name) l = re.split(r"_(\d+)", m_name)
else: else:
l = [m_name] l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma': if l[0] == "kernel" or l[0] == "gamma":
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, "weight")
elif l[0] == 'output_bias' or l[0] == 'beta': elif l[0] == "output_bias" or l[0] == "beta":
pointer = getattr(pointer, 'bias') pointer = getattr(pointer, "bias")
elif l[0] == 'output_weights': elif l[0] == "output_weights":
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, "weight")
elif l[0] == 'squad': elif l[0] == "squad":
pointer = getattr(pointer, 'classifier') pointer = getattr(pointer, "classifier")
else: else:
try: try:
pointer = getattr(pointer, l[0]) pointer = getattr(pointer, l[0])
...@@ -105,9 +107,9 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path): ...@@ -105,9 +107,9 @@ def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
if len(l) >= 2: if len(l) >= 2:
num = int(l[1]) num = int(l[1])
pointer = pointer[num] pointer = pointer[num]
if m_name[-11:] == '_embeddings': if m_name[-11:] == "_embeddings":
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, "weight")
elif m_name == 'kernel': elif m_name == "kernel":
array = np.transpose(array) array = np.transpose(array)
try: try:
assert pointer.shape == array.shape assert pointer.shape == array.shape
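# Hedged sketch of how a TensorFlow variable name is walked by the loop above; the checkpoint
# name is an illustrative example. Scope components like "layer_3" are split into an attribute
# name plus an index, and "kernel"/"gamma" map onto the PyTorch "weight" attribute
# (kernel arrays are additionally transposed).
import re

name = "bert/encoder/layer_3/attention/output/dense/kernel"
steps = []
for m_name in name.split("/"):
    if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
        attr, index, _ = re.split(r"_(\d+)", m_name)
        steps.append((attr, int(index)))
    elif m_name in ("kernel", "gamma"):
        steps.append(("weight", None))
    else:
        steps.append((m_name, None))
assert steps[2] == ("layer", 3) and steps[-1] == ("weight", None)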
...@@ -147,7 +149,6 @@ class XxxLayer(nn.Module): ...@@ -147,7 +149,6 @@ class XxxLayer(nn.Module):
return outputs return outputs
#################################################### ####################################################
# PreTrainedModel is a sub-class of torch.nn.Module # PreTrainedModel is a sub-class of torch.nn.Module
# which take care of loading and saving pretrained weights # which take care of loading and saving pretrained weights
...@@ -161,6 +162,7 @@ class XxxPreTrainedModel(PreTrainedModel): ...@@ -161,6 +162,7 @@ class XxxPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and """ An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models. a simple interface for downloading and loading pretrained models.
""" """
config_class = XxxConfig config_class = XxxConfig
pretrained_model_archive_map = XXX_PRETRAINED_MODEL_ARCHIVE_MAP pretrained_model_archive_map = XXX_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_xxx load_tf_weights = load_tf_weights_in_xxx
...@@ -246,8 +248,12 @@ XXX_INPUTS_DOCSTRING = r""" ...@@ -246,8 +248,12 @@ XXX_INPUTS_DOCSTRING = r"""
than the model's internal embedding lookup matrix. than the model's internal embedding lookup matrix.
""" """
@add_start_docstrings("The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) @add_start_docstrings(
"The bare Xxx Model transformer outputting raw hidden-states without any specific head on top.",
XXX_START_DOCSTRING,
XXX_INPUTS_DOCSTRING,
)
class XxxModel(XxxPreTrainedModel): class XxxModel(XxxPreTrainedModel):
r""" r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
...@@ -277,6 +283,7 @@ class XxxModel(XxxPreTrainedModel): ...@@ -277,6 +283,7 @@ class XxxModel(XxxPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
""" """
def __init__(self, config): def __init__(self, config):
super(XxxModel, self).__init__(config) super(XxxModel, self).__init__(config)
...@@ -300,7 +307,15 @@ class XxxModel(XxxPreTrainedModel): ...@@ -300,7 +307,15 @@ class XxxModel(XxxPreTrainedModel):
for layer, heads in heads_to_prune.items(): for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads) self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None): def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
):
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -329,7 +344,7 @@ class XxxModel(XxxPreTrainedModel): ...@@ -329,7 +344,7 @@ class XxxModel(XxxPreTrainedModel):
# positions we want to attend and -10000.0 for masked positions. # positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is # Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely. # effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
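# Hedged numerical sketch of the additive mask built above (hypothetical toy tensor):
# 1.0 marks real tokens, 0.0 marks padding; after the transform, padded positions receive a
# large negative bias so their softmax weight is effectively zero.
import torch

attention_mask = torch.tensor([[1.0, 1.0, 0.0]])             # [batch, seq_len]
extended = attention_mask[:, None, None, :]                  # [batch, 1, 1, seq_len], broadcast over heads
extended = (1.0 - extended) * -10000.0
assert extended[0, 0, 0, 0].item() == 0.0 and extended[0, 0, 0, 2].item() == -10000.0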
# Prepare head mask if needed # Prepare head mask if needed
...@@ -342,14 +357,20 @@ class XxxModel(XxxPreTrainedModel): ...@@ -342,14 +357,20 @@ class XxxModel(XxxPreTrainedModel):
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2: elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer head_mask = (
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
) # We can specify head_mask for each layer
head_mask = head_mask.to(
dtype=next(self.parameters()).dtype
) # switch to float if needed + fp16 compatibility
else: else:
head_mask = [None] * self.config.num_hidden_layers head_mask = [None] * self.config.num_hidden_layers
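# Hedged shape sketch for the two head_mask branches above (toy sizes, illustrative only):
# a 1D mask is shared across layers, a 2D mask provides one row per layer; both end up
# broadcastable against attention probabilities of shape [batch, num_heads, seq, seq].
import torch

num_layers, num_heads = 2, 4
mask_1d = torch.ones(num_heads)
shared = mask_1d.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(num_layers, -1, -1, -1, -1)
assert shared.shape == (num_layers, 1, num_heads, 1, 1)

mask_2d = torch.ones(num_layers, num_heads)
per_layer = mask_2d.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
assert per_layer.shape == (num_layers, 1, num_heads, 1, 1)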
################################## ##################################
# Replace this with your model code # Replace this with your model code
embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds) embedding_output = self.embeddings(
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
)
encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask) encoder_outputs = self.encoder(embedding_output, extended_attention_mask, head_mask=head_mask)
sequence_output = encoder_outputs[0] sequence_output = encoder_outputs[0]
outputs = (sequence_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here outputs = (sequence_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
...@@ -357,8 +378,9 @@ class XxxModel(XxxPreTrainedModel): ...@@ -357,8 +378,9 @@ class XxxModel(XxxPreTrainedModel):
return outputs # sequence_output, (hidden_states), (attentions) return outputs # sequence_output, (hidden_states), (attentions)
@add_start_docstrings("""Xxx Model with a `language modeling` head on top. """, @add_start_docstrings(
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) """Xxx Model with a `language modeling` head on top. """, XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING
)
class XxxForMaskedLM(XxxPreTrainedModel): class XxxForMaskedLM(XxxPreTrainedModel):
r""" r"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
...@@ -389,6 +411,7 @@ class XxxForMaskedLM(XxxPreTrainedModel): ...@@ -389,6 +411,7 @@ class XxxForMaskedLM(XxxPreTrainedModel):
loss, prediction_scores = outputs[:2] loss, prediction_scores = outputs[:2]
""" """
def __init__(self, config): def __init__(self, config):
super(XxxForMaskedLM, self).__init__(config) super(XxxForMaskedLM, self).__init__(config)
...@@ -400,15 +423,25 @@ class XxxForMaskedLM(XxxPreTrainedModel): ...@@ -400,15 +423,25 @@ class XxxForMaskedLM(XxxPreTrainedModel):
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_head return self.lm_head
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, def forward(
masked_lm_labels=None): self,
input_ids=None,
outputs = self.transformer(input_ids, attention_mask=None,
attention_mask=attention_mask, token_type_ids=None,
token_type_ids=token_type_ids, position_ids=None,
position_ids=position_ids, head_mask=None,
head_mask=head_mask, inputs_embeds=None,
inputs_embeds=inputs_embeds) masked_lm_labels=None,
):
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0] sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output) prediction_scores = self.cls(sequence_output)
...@@ -422,9 +455,12 @@ class XxxForMaskedLM(XxxPreTrainedModel): ...@@ -422,9 +455,12 @@ class XxxForMaskedLM(XxxPreTrainedModel):
return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
@add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of @add_start_docstrings(
"""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """, the pooled output) e.g. for GLUE tasks. """,
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) XXX_START_DOCSTRING,
XXX_INPUTS_DOCSTRING,
)
class XxxForSequenceClassification(XxxPreTrainedModel): class XxxForSequenceClassification(XxxPreTrainedModel):
r""" r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
...@@ -456,6 +492,7 @@ class XxxForSequenceClassification(XxxPreTrainedModel): ...@@ -456,6 +492,7 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
loss, logits = outputs[:2] loss, logits = outputs[:2]
""" """
def __init__(self, config): def __init__(self, config):
super(XxxForSequenceClassification, self).__init__(config) super(XxxForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
...@@ -466,15 +503,25 @@ class XxxForSequenceClassification(XxxPreTrainedModel): ...@@ -466,15 +503,25 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
self.init_weights() self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, def forward(
position_ids=None, head_mask=None, inputs_embeds=None, labels=None): self,
input_ids=None,
outputs = self.transformer(input_ids, attention_mask=None,
attention_mask=attention_mask, token_type_ids=None,
token_type_ids=token_type_ids, position_ids=None,
position_ids=position_ids, head_mask=None,
head_mask=head_mask, inputs_embeds=None,
inputs_embeds=inputs_embeds) labels=None,
):
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
pooled_output = outputs[1] pooled_output = outputs[1]
...@@ -496,9 +543,12 @@ class XxxForSequenceClassification(XxxPreTrainedModel): ...@@ -496,9 +543,12 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
return outputs # (loss), logits, (hidden_states), (attentions) return outputs # (loss), logits, (hidden_states), (attentions)
@add_start_docstrings("""Xxx Model with a token classification head on top (a linear layer on top of @add_start_docstrings(
"""Xxx Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) XXX_START_DOCSTRING,
XXX_INPUTS_DOCSTRING,
)
class XxxForTokenClassification(XxxPreTrainedModel): class XxxForTokenClassification(XxxPreTrainedModel):
r""" r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
...@@ -528,6 +578,7 @@ class XxxForTokenClassification(XxxPreTrainedModel): ...@@ -528,6 +578,7 @@ class XxxForTokenClassification(XxxPreTrainedModel):
loss, scores = outputs[:2] loss, scores = outputs[:2]
""" """
def __init__(self, config): def __init__(self, config):
super(XxxForTokenClassification, self).__init__(config) super(XxxForTokenClassification, self).__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
...@@ -538,15 +589,25 @@ class XxxForTokenClassification(XxxPreTrainedModel): ...@@ -538,15 +589,25 @@ class XxxForTokenClassification(XxxPreTrainedModel):
self.init_weights() self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, def forward(
position_ids=None, head_mask=None, inputs_embeds=None, labels=None): self,
input_ids=None,
outputs = self.transformer(input_ids, attention_mask=None,
attention_mask=attention_mask, token_type_ids=None,
token_type_ids=token_type_ids, position_ids=None,
position_ids=position_ids, head_mask=None,
head_mask=head_mask, inputs_embeds=None,
inputs_embeds=inputs_embeds) labels=None,
):
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0] sequence_output = outputs[0]
...@@ -569,9 +630,12 @@ class XxxForTokenClassification(XxxPreTrainedModel): ...@@ -569,9 +630,12 @@ class XxxForTokenClassification(XxxPreTrainedModel):
return outputs # (loss), scores, (hidden_states), (attentions) return outputs # (loss), scores, (hidden_states), (attentions)
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of @add_start_docstrings(
"""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
the hidden-states output to compute `span start logits` and `span end logits`). """, the hidden-states output to compute `span start logits` and `span end logits`). """,
XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING) XXX_START_DOCSTRING,
XXX_INPUTS_DOCSTRING,
)
class XxxForQuestionAnswering(XxxPreTrainedModel): class XxxForQuestionAnswering(XxxPreTrainedModel):
r""" r"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
...@@ -613,6 +677,7 @@ class XxxForQuestionAnswering(XxxPreTrainedModel): ...@@ -613,6 +677,7 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
""" """
def __init__(self, config): def __init__(self, config):
super(XxxForQuestionAnswering, self).__init__(config) super(XxxForQuestionAnswering, self).__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
...@@ -622,15 +687,26 @@ class XxxForQuestionAnswering(XxxPreTrainedModel): ...@@ -622,15 +687,26 @@ class XxxForQuestionAnswering(XxxPreTrainedModel):
self.init_weights() self.init_weights()
def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, def forward(
start_positions=None, end_positions=None): self,
input_ids=None,
outputs = self.transformer(input_ids, attention_mask=None,
attention_mask=attention_mask, token_type_ids=None,
token_type_ids=token_type_ids, position_ids=None,
position_ids=position_ids, head_mask=None,
head_mask=head_mask, inputs_embeds=None,
inputs_embeds=inputs_embeds) start_positions=None,
end_positions=None,
):
outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
)
sequence_output = outputs[0] sequence_output = outputs[0]
......
...@@ -19,7 +19,7 @@ from __future__ import print_function ...@@ -19,7 +19,7 @@ from __future__ import print_function
import unittest import unittest
import sys import sys
from .modeling_tf_common_test import (TFCommonTestCases, ids_tensor) from .modeling_tf_common_test import TFCommonTestCases, ids_tensor
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import CACHE_DIR, require_tf, slow from .utils import CACHE_DIR, require_tf, slow
...@@ -27,46 +27,57 @@ from transformers import XxxConfig, is_tf_available ...@@ -27,46 +27,57 @@ from transformers import XxxConfig, is_tf_available
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
from transformers.modeling_tf_xxx import (TFXxxModel, TFXxxForMaskedLM, from transformers.modeling_tf_xxx import (
TFXxxForSequenceClassification, TFXxxModel,
TFXxxForTokenClassification, TFXxxForMaskedLM,
TFXxxForQuestionAnswering, TFXxxForSequenceClassification,
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP) TFXxxForTokenClassification,
TFXxxForQuestionAnswering,
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP,
)
@require_tf @require_tf
class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester): class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes = (TFXxxModel, TFXxxForMaskedLM, TFXxxForQuestionAnswering, all_model_classes = (
TFXxxForSequenceClassification, (
TFXxxForTokenClassification) if is_tf_available() else () TFXxxModel,
TFXxxForMaskedLM,
TFXxxForQuestionAnswering,
TFXxxForSequenceClassification,
TFXxxForTokenClassification,
)
if is_tf_available()
else ()
)
class TFXxxModelTester(object): class TFXxxModelTester(object):
def __init__(
def __init__(self, self,
parent, parent,
batch_size=13, batch_size=13,
seq_length=7, seq_length=7,
is_training=True, is_training=True,
use_input_mask=True, use_input_mask=True,
use_token_type_ids=True, use_token_type_ids=True,
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
hidden_size=32, hidden_size=32,
num_hidden_layers=5, num_hidden_layers=5,
num_attention_heads=4, num_attention_heads=4,
intermediate_size=37, intermediate_size=37,
hidden_act="gelu", hidden_act="gelu",
hidden_dropout_prob=0.1, hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
max_position_embeddings=512, max_position_embeddings=512,
type_vocab_size=16, type_vocab_size=16,
type_sequence_label_size=2, type_sequence_label_size=2,
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
num_choices=4, num_choices=4,
scope=None, scope=None,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
...@@ -120,15 +131,16 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester): ...@@ -120,15 +131,16 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
attention_probs_dropout_prob=self.attention_probs_dropout_prob, attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range) initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_xxx_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFXxxModel(config=config) model = TFXxxModel(config=config)
inputs = {'input_ids': input_ids, inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
sequence_output, pooled_output = model(inputs) sequence_output, pooled_output = model(inputs)
inputs = [input_ids, input_mask] inputs = [input_ids, input_mask]
...@@ -141,78 +153,74 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
"pooled_output": pooled_output.numpy(), "pooled_output": pooled_output.numpy(),
} }
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].shape), list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
[self.batch_size, self.seq_length, self.hidden_size]) )
self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
def create_and_check_xxx_for_masked_lm(
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFXxxForMaskedLM(config=config) model = TFXxxForMaskedLM(config=config)
inputs = {'input_ids': input_ids, inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
'attention_mask': input_mask, (prediction_scores,) = model(inputs)
'token_type_ids': token_type_ids}
prediction_scores, = model(inputs)
result = { result = {
"prediction_scores": prediction_scores.numpy(), "prediction_scores": prediction_scores.numpy(),
} }
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].shape), list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
[self.batch_size, self.seq_length, self.vocab_size]) )
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_xxx_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels config.num_labels = self.num_labels
model = TFXxxForSequenceClassification(config=config) model = TFXxxForSequenceClassification(config=config)
inputs = {'input_ids': input_ids, inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
'attention_mask': input_mask, (logits,) = model(inputs)
'token_type_ids': token_type_ids}
logits, = model(inputs)
result = { result = {
"logits": logits.numpy(), "logits": logits.numpy(),
} }
self.parent.assertListEqual( self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_labels])
list(result["logits"].shape),
[self.batch_size, self.num_labels])
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_xxx_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels config.num_labels = self.num_labels
model = TFXxxForTokenClassification(config=config) model = TFXxxForTokenClassification(config=config)
inputs = {'input_ids': input_ids, inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
'attention_mask': input_mask, (logits,) = model(inputs)
'token_type_ids': token_type_ids}
logits, = model(inputs)
result = { result = {
"logits": logits.numpy(), "logits": logits.numpy(),
} }
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].shape), list(result["logits"].shape), [self.batch_size, self.seq_length, self.num_labels]
[self.batch_size, self.seq_length, self.num_labels]) )
def create_and_check_xxx_for_question_answering(
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = TFXxxForQuestionAnswering(config=config) model = TFXxxForQuestionAnswering(config=config)
inputs = {'input_ids': input_ids, inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
'attention_mask': input_mask,
'token_type_ids': token_type_ids}
start_logits, end_logits = model(inputs) start_logits, end_logits = model(inputs)
result = { result = {
"start_logits": start_logits.numpy(), "start_logits": start_logits.numpy(),
"end_logits": end_logits.numpy(), "end_logits": end_logits.numpy(),
} }
self.parent.assertListEqual( self.parent.assertListEqual(list(result["start_logits"].shape), [self.batch_size, self.seq_length])
list(result["start_logits"].shape), self.parent.assertListEqual(list(result["end_logits"].shape), [self.batch_size, self.seq_length])
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].shape),
[self.batch_size, self.seq_length])
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask, (
sequence_labels, token_labels, choice_labels) = config_and_inputs config,
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask} input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict return config, inputs_dict
def setUp(self): def setUp(self):
...@@ -244,9 +252,10 @@ class TFXxxModelTest(TFCommonTestCases.TFCommonModelTester):
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in ['xxx-base-uncased']: for model_name in ["xxx-base-uncased"]:
model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = TFXxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
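One detail in the hunks above is easy to misread: black rewrites single-target tuple unpacking such as `prediction_scores, = model(inputs)` into the parenthesized form `(prediction_scores,) = model(inputs)`. Both spellings unpack a one-element tuple; only the emphasis changes. A minimal sketch, using an illustrative `fake_model` helper that is not part of the library:

def fake_model(inputs):
    # Stand-in for a model call that returns a one-element tuple of outputs.
    return (sum(inputs),)

# Pre-black spelling: the trailing comma does the unpacking, but is easy to overlook.
prediction_scores, = fake_model([1, 2, 3])

# Post-black spelling: the parentheses make the single-element unpacking explicit.
(prediction_scores_new,) = fake_model([1, 2, 3])

assert prediction_scores == prediction_scores_new == 6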
...@@ -20,51 +20,60 @@ import unittest
from transformers import is_torch_available from transformers import is_torch_available
from .modeling_common_test import (CommonTestCases, ids_tensor) from .modeling_common_test import CommonTestCases, ids_tensor
from .configuration_common_test import ConfigTester from .configuration_common_test import ConfigTester
from .utils import CACHE_DIR, require_torch, slow, torch_device from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
from transformers import (XxxConfig, XxxModel, XxxForMaskedLM, from transformers import (
XxxForNextSentencePrediction, XxxForPreTraining, XxxConfig,
XxxForQuestionAnswering, XxxForSequenceClassification, XxxModel,
XxxForTokenClassification, XxxForMultipleChoice) XxxForMaskedLM,
XxxForNextSentencePrediction,
XxxForPreTraining,
XxxForQuestionAnswering,
XxxForSequenceClassification,
XxxForTokenClassification,
XxxForMultipleChoice,
)
from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_MAP from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_MAP
@require_torch @require_torch
class XxxModelTest(CommonTestCases.CommonModelTester): class XxxModelTest(CommonTestCases.CommonModelTester):
all_model_classes = (XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, all_model_classes = (
XxxForSequenceClassification, (XxxModel, XxxForMaskedLM, XxxForQuestionAnswering, XxxForSequenceClassification, XxxForTokenClassification)
XxxForTokenClassification) if is_torch_available() else () if is_torch_available()
else ()
)
class XxxModelTester(object): class XxxModelTester(object):
def __init__(
def __init__(self, self,
parent, parent,
batch_size=13, batch_size=13,
seq_length=7, seq_length=7,
is_training=True, is_training=True,
use_input_mask=True, use_input_mask=True,
use_token_type_ids=True, use_token_type_ids=True,
use_labels=True, use_labels=True,
vocab_size=99, vocab_size=99,
hidden_size=32, hidden_size=32,
num_hidden_layers=5, num_hidden_layers=5,
num_attention_heads=4, num_attention_heads=4,
intermediate_size=37, intermediate_size=37,
hidden_act="gelu", hidden_act="gelu",
hidden_dropout_prob=0.1, hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
max_position_embeddings=512, max_position_embeddings=512,
type_vocab_size=16, type_vocab_size=16,
type_sequence_label_size=2, type_sequence_label_size=2,
initializer_range=0.02, initializer_range=0.02,
num_labels=3, num_labels=3,
num_choices=4, num_choices=4,
scope=None, scope=None,
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
self.seq_length = seq_length self.seq_length = seq_length
...@@ -118,16 +127,17 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
attention_probs_dropout_prob=self.attention_probs_dropout_prob, attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size, type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range) initializer_range=self.initializer_range,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def check_loss_output(self, result): def check_loss_output(self, result):
self.parent.assertListEqual( self.parent.assertListEqual(list(result["loss"].size()), [])
list(result["loss"].size()),
[])
def create_and_check_xxx_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): def create_and_check_xxx_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxModel(config=config) model = XxxModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
...@@ -140,83 +150,98 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
"pooled_output": pooled_output, "pooled_output": pooled_output,
} }
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["sequence_output"].size()), list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
[self.batch_size, self.seq_length, self.hidden_size]) )
self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size])
def create_and_check_xxx_for_masked_lm(
def create_and_check_xxx_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxForMaskedLM(config=config) model = XxxForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels) loss, prediction_scores = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, masked_lm_labels=token_labels
)
result = { result = {
"loss": loss, "loss": loss,
"prediction_scores": prediction_scores, "prediction_scores": prediction_scores,
} }
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["prediction_scores"].size()), list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]
[self.batch_size, self.seq_length, self.vocab_size]) )
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_xxx_for_question_answering(
def create_and_check_xxx_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = XxxForQuestionAnswering(config=config) model = XxxForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, loss, start_logits, end_logits = model(
start_positions=sequence_labels, end_positions=sequence_labels) input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = { result = {
"loss": loss, "loss": loss,
"start_logits": start_logits, "start_logits": start_logits,
"end_logits": end_logits, "end_logits": end_logits,
} }
self.parent.assertListEqual( self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
list(result["start_logits"].size()), self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
[self.batch_size, self.seq_length])
self.parent.assertListEqual(
list(result["end_logits"].size()),
[self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_xxx_for_sequence_classification(
def create_and_check_xxx_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels config.num_labels = self.num_labels
model = XxxForSequenceClassification(config) model = XxxForSequenceClassification(config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels
)
result = { result = {
"loss": loss, "loss": loss,
"logits": logits, "logits": logits,
} }
self.parent.assertListEqual( self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels])
list(result["logits"].size()),
[self.batch_size, self.num_labels])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_xxx_for_token_classification(
def create_and_check_xxx_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels config.num_labels = self.num_labels
model = XxxForTokenClassification(config=config) model = XxxForTokenClassification(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels) loss, logits = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels
)
result = { result = {
"loss": loss, "loss": loss,
"logits": logits, "logits": logits,
} }
self.parent.assertListEqual( self.parent.assertListEqual(
list(result["logits"].size()), list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels]
[self.batch_size, self.seq_length, self.num_labels]) )
self.check_loss_output(result) self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, token_type_ids, input_mask, (
sequence_labels, token_labels, choice_labels) = config_and_inputs config,
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask} input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict return config, inputs_dict
def setUp(self): def setUp(self):
...@@ -252,5 +277,6 @@ class XxxModelTest(CommonTestCases.CommonModelTester):
model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR) model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model) self.assertIsNotNone(model)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
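The `check_loss_output` helper above asserts that `list(result["loss"].size())` equals `[]`, i.e. that the loss comes back as a zero-dimensional (scalar) tensor, while the logits keep their batched shapes. A minimal sketch of what those assertions accept, assuming only a local PyTorch install:

import torch

scalar_loss = torch.tensor(0.37)        # 0-d tensor, like the loss returned by the model heads
batched_logits = torch.zeros(13, 3)     # e.g. (batch_size, num_labels)

assert list(scalar_loss.size()) == []             # what check_loss_output expects
assert list(batched_logits.size()) == [13, 3]     # mirrors the shape assertions in the tests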
...@@ -18,10 +18,11 @@ import os
import unittest import unittest
from io import open from io import open
from transformers.tokenization_bert import (XxxTokenizer, VOCAB_FILES_NAMES) from transformers.tokenization_bert import XxxTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases from .tokenization_tests_commons import CommonTestCases
class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester): class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
tokenizer_class = XxxTokenizer tokenizer_class = XxxTokenizer
...@@ -30,28 +31,39 @@ class XxxTokenizationTest(CommonTestCases.CommonTokenizerTester):
super(XxxTokenizationTest, self).setUp() super(XxxTokenizationTest, self).setUp()
vocab_tokens = [ vocab_tokens = [
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "[UNK]",
"##ing", ",", "low", "lowest", "[CLS]",
"[SEP]",
"want",
"##want",
"##ed",
"wa",
"un",
"runn",
"##ing",
",",
"low",
"lowest",
] ]
self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES['vocab_file']) self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
with open(self.vocab_file, "w", encoding='utf-8') as vocab_writer: with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_tokenizer(self, **kwargs): def get_tokenizer(self, **kwargs):
return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs) return XxxTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self): def get_input_output_texts(self):
input_text = u"UNwant\u00E9d,running" input_text = "UNwant\u00E9d,running"
output_text = u"unwanted, running" output_text = "unwanted, running"
return input_text, output_text return input_text, output_text
def test_full_tokenizer(self): def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file) tokenizer = self.tokenizer_class(self.vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") tokens = tokenizer.tokenize("UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"]) self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9]) self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()
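The tokenizer test above writes a tiny WordPiece vocabulary to disk and checks a fixed tokenization. The `Xxx` names are template placeholders; the same round trip can be sketched with `BertTokenizer`, which this template mirrors (the temporary-directory handling below is illustrative, not the test harness's own setup):

import os
import tempfile

from transformers import BertTokenizer

vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing", ",", "low", "lowest"]

tmpdirname = tempfile.mkdtemp()
vocab_file = os.path.join(tmpdirname, "vocab.txt")
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
    vocab_writer.write("".join(token + "\n" for token in vocab_tokens))

tokenizer = BertTokenizer(vocab_file)
tokens = tokenizer.tokenize("UNwant\u00E9d,running")
assert tokens == ["un", "##want", "##ed", ",", "runn", "##ing"]
assert tokenizer.convert_tokens_to_ids(tokens) == [7, 4, 5, 10, 8, 9]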
...@@ -34,17 +34,16 @@ logger = logging.getLogger(__name__)
# Mapping from the keyword arguments names of Tokenizer `__init__` # Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances # to file names for serializing Tokenizer instances
#################################################### ####################################################
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
#################################################### ####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__` # Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names. # to pretrained vocabulary URL for all the model shortcut names.
#################################################### ####################################################
PRETRAINED_VOCAB_FILES_MAP = { PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': "vocab_file": {
{ "xxx-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt",
'xxx-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-base-uncased-vocab.txt", "xxx-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
'xxx-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/xxx-large-uncased-vocab.txt",
} }
} }
...@@ -52,8 +51,8 @@ PRETRAINED_VOCAB_FILES_MAP = {
# Mapping from model shortcut names to max length of inputs # Mapping from model shortcut names to max length of inputs
#################################################### ####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xxx-base-uncased': 512, "xxx-base-uncased": 512,
'xxx-large-uncased': 512, "xxx-large-uncased": 512,
} }
#################################################### ####################################################
...@@ -62,8 +61,8 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
# To be used for checkpoint specific configurations. # To be used for checkpoint specific configurations.
#################################################### ####################################################
PRETRAINED_INIT_CONFIGURATION = { PRETRAINED_INIT_CONFIGURATION = {
'xxx-base-uncased': {'do_lower_case': True}, "xxx-base-uncased": {"do_lower_case": True},
'xxx-large-uncased': {'do_lower_case': True}, "xxx-large-uncased": {"do_lower_case": True},
} }
...@@ -73,7 +72,7 @@ def load_vocab(vocab_file):
with open(vocab_file, "r", encoding="utf-8") as reader: with open(vocab_file, "r", encoding="utf-8") as reader:
tokens = reader.readlines() tokens = reader.readlines()
for index, token in enumerate(tokens): for index, token in enumerate(tokens):
token = token.rstrip('\n') token = token.rstrip("\n")
vocab[token] = index vocab[token] = index
return vocab return vocab
...@@ -93,9 +92,17 @@ class XxxTokenizer(PreTrainedTokenizer):
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, do_lower_case=True, def __init__(
unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", self,
mask_token="[MASK]", **kwargs): vocab_file,
do_lower_case=True,
unk_token="[UNK]",
sep_token="[SEP]",
pad_token="[PAD]",
cls_token="[CLS]",
mask_token="[MASK]",
**kwargs
):
"""Constructs a XxxTokenizer. """Constructs a XxxTokenizer.
Args: Args:
...@@ -104,16 +111,22 @@ class XxxTokenizer(PreTrainedTokenizer):
Whether to lower case the input Whether to lower case the input
Only has an effect when do_basic_tokenize=True Only has an effect when do_basic_tokenize=True
""" """
super(XxxTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, super(XxxTokenizer, self).__init__(
pad_token=pad_token, cls_token=cls_token, unk_token=unk_token,
mask_token=mask_token, **kwargs) sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
if not os.path.isfile(vocab_file): if not os.path.isfile(vocab_file):
raise ValueError( raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
)
self.vocab = load_vocab(vocab_file) self.vocab = load_vocab(vocab_file)
@property @property
...@@ -142,7 +155,7 @@ class XxxTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens): def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """ """ Converts a sequence of tokens (string) in a single string. """
out_string = ' '.join(tokens).replace(' ##', '').strip() out_string = " ".join(tokens).replace(" ##", "").strip()
return out_string return out_string
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
...@@ -177,8 +190,10 @@ class XxxTokenizer(PreTrainedTokenizer):
if already_has_special_tokens: if already_has_special_tokens:
if token_ids_1 is not None: if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of " raise ValueError(
"ids is already formated with special tokens for the model.") "You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None: if token_ids_1 is not None:
...@@ -204,15 +219,17 @@ class XxxTokenizer(PreTrainedTokenizer):
"""Save the tokenizer vocabulary to a directory or file.""" """Save the tokenizer vocabulary to a directory or file."""
index = 0 index = 0
if os.path.isdir(vocab_path): if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
else: else:
vocab_file = vocab_path vocab_file = vocab_path
with open(vocab_file, "w", encoding="utf-8") as writer: with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
if index != token_index: if index != token_index:
logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." logger.warning(
" Please check that the vocabulary is not corrupted!".format(vocab_file)) "Saving vocabulary to {}: vocabulary indices are not consecutive."
" Please check that the vocabulary is not corrupted!".format(vocab_file)
)
index = token_index index = token_index
writer.write(token + u'\n') writer.write(token + "\n")
index += 1 index += 1
return (vocab_file,) return (vocab_file,)
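In the constructor above, `max_len_single_sentence` and `max_len_sentences_pair` reserve 2 and 3 positions because a BERT-style encoding wraps each segment in `[CLS]`/`[SEP]` markers. A minimal standalone sketch of the pattern that `build_inputs_with_special_tokens` implements (the id values 101/102 are illustrative defaults, not taken from this template):

def add_special_tokens(token_ids_0, token_ids_1=None, cls_id=101, sep_id=102):
    # Single sequence: [CLS] A [SEP]  -> two special tokens reserved
    if token_ids_1 is None:
        return [cls_id] + token_ids_0 + [sep_id]
    # Pair of sequences: [CLS] A [SEP] B [SEP]  -> three special tokens reserved
    return [cls_id] + token_ids_0 + [sep_id] + token_ids_1 + [sep_id]

assert add_special_tokens([7, 8]) == [101, 7, 8, 102]
assert add_special_tokens([7, 8], [9]) == [101, 7, 8, 102, 9, 102]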
...@@ -6,8 +6,9 @@ __version__ = "2.3.0"
# and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493 # and: https://github.com/tensorflow/tensorflow/issues/26691#issuecomment-500369493
try: try:
import absl.logging import absl.logging
absl.logging.set_verbosity('info')
absl.logging.set_stderrthreshold('info') absl.logging.set_verbosity("info")
absl.logging.set_stderrthreshold("info")
absl.logging._warn_preinit_stderr = False absl.logging._warn_preinit_stderr = False
except: except:
pass pass
...@@ -17,19 +18,41 @@ import logging
logger = logging.getLogger(__name__) # pylint: disable=invalid-name logger = logging.getLogger(__name__) # pylint: disable=invalid-name
# Files and general utilities # Files and general utilities
from .file_utils import (TRANSFORMERS_CACHE, PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, from .file_utils import (
cached_path, add_start_docstrings, add_end_docstrings, TRANSFORMERS_CACHE,
WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME, CONFIG_NAME, MODEL_CARD_NAME, PYTORCH_TRANSFORMERS_CACHE,
is_tf_available, is_torch_available) PYTORCH_PRETRAINED_BERT_CACHE,
cached_path,
from .data import (is_sklearn_available, add_start_docstrings,
InputExample, InputFeatures, DataProcessor, add_end_docstrings,
SingleSentenceClassificationProcessor, WEIGHTS_NAME,
glue_output_modes, glue_convert_examples_to_features, TF2_WEIGHTS_NAME,
glue_processors, glue_tasks_num_labels, TF_WEIGHTS_NAME,
xnli_output_modes, xnli_processors, xnli_tasks_num_labels, CONFIG_NAME,
squad_convert_examples_to_features, SquadFeatures, MODEL_CARD_NAME,
SquadExample, SquadV1Processor, SquadV2Processor) is_tf_available,
is_torch_available,
)
from .data import (
is_sklearn_available,
InputExample,
InputFeatures,
DataProcessor,
SingleSentenceClassificationProcessor,
glue_output_modes,
glue_convert_examples_to_features,
glue_processors,
glue_tasks_num_labels,
xnli_output_modes,
xnli_processors,
xnli_tasks_num_labels,
squad_convert_examples_to_features,
SquadFeatures,
SquadExample,
SquadV1Processor,
SquadV2Processor,
)
if is_sklearn_available(): if is_sklearn_available():
from .data import glue_compute_metrics, xnli_compute_metrics from .data import glue_compute_metrics, xnli_compute_metrics
...@@ -38,12 +61,12 @@ if is_sklearn_available():
from .modelcard import ModelCard from .modelcard import ModelCard
# Tokenizers # Tokenizers
from .tokenization_utils import (PreTrainedTokenizer) from .tokenization_utils import PreTrainedTokenizer
from .tokenization_auto import AutoTokenizer from .tokenization_auto import AutoTokenizer
from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer, MecabTokenizer, CharacterTokenizer from .tokenization_bert_japanese import BertJapaneseTokenizer, MecabTokenizer, CharacterTokenizer
from .tokenization_openai import OpenAIGPTTokenizer from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLCorpus
from .tokenization_gpt2 import GPT2Tokenizer from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_ctrl import CTRLTokenizer from .tokenization_ctrl import CTRLTokenizer
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
...@@ -75,143 +98,281 @@ from .configuration_mmbt import MMBTConfig
# Modeling # Modeling
if is_torch_available(): if is_torch_available():
from .modeling_utils import (PreTrainedModel, prune_layer, Conv1D) from .modeling_utils import PreTrainedModel, prune_layer, Conv1D
from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, from .modeling_auto import (
AutoModelWithLMHead, AutoModelForTokenClassification, ALL_PRETRAINED_MODEL_ARCHIVE_MAP) AutoModel,
AutoModelForSequenceClassification,
from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining, AutoModelForQuestionAnswering,
BertForMaskedLM, BertForNextSentencePrediction, AutoModelWithLMHead,
BertForSequenceClassification, BertForMultipleChoice, AutoModelForTokenClassification,
BertForTokenClassification, BertForQuestionAnswering, ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP) )
from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, from .modeling_bert import (
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) BertPreTrainedModel,
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, BertModel,
AdaptiveEmbedding, BertForPreTraining,
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) BertForMaskedLM,
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model, BertForNextSentencePrediction,
GPT2LMHeadModel, GPT2DoubleHeadsModel, BertForSequenceClassification,
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) BertForMultipleChoice,
from .modeling_ctrl import (CTRLPreTrainedModel, CTRLModel, BertForTokenClassification,
CTRLLMHeadModel, BertForQuestionAnswering,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) load_tf_weights_in_bert,
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetForSequenceClassification, XLNetForTokenClassification, )
XLNetForMultipleChoice, XLNetForQuestionAnsweringSimple, from .modeling_openai import (
XLNetForQuestionAnswering, load_tf_weights_in_xlnet, OpenAIGPTPreTrainedModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) OpenAIGPTModel,
from .modeling_xlm import (XLMPreTrainedModel , XLMModel, OpenAIGPTLMHeadModel,
XLMWithLMHeadModel, XLMForSequenceClassification, OpenAIGPTDoubleHeadsModel,
XLMForQuestionAnswering, XLMForQuestionAnsweringSimple, load_tf_weights_in_openai_gpt,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP) OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, )
RobertaForSequenceClassification, RobertaForMultipleChoice, from .modeling_transfo_xl import (
RobertaForTokenClassification, RobertaForQuestionAnswering, TransfoXLPreTrainedModel,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) TransfoXLModel,
from .modeling_distilbert import (DistilBertPreTrainedModel, DistilBertForMaskedLM, DistilBertModel, TransfoXLLMHeadModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering, AdaptiveEmbedding,
DistilBertForTokenClassification, load_tf_weights_in_transfo_xl,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
from .modeling_camembert import (CamembertForMaskedLM, CamembertModel, )
CamembertForSequenceClassification, CamembertForMultipleChoice, from .modeling_gpt2 import (
CamembertForTokenClassification, GPT2PreTrainedModel,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP) GPT2Model,
GPT2LMHeadModel,
GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
from .modeling_xlnet import (
XLNetPreTrainedModel,
XLNetModel,
XLNetLMHeadModel,
XLNetForSequenceClassification,
XLNetForTokenClassification,
XLNetForMultipleChoice,
XLNetForQuestionAnsweringSimple,
XLNetForQuestionAnswering,
load_tf_weights_in_xlnet,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_xlm import (
XLMPreTrainedModel,
XLMModel,
XLMWithLMHeadModel,
XLMForSequenceClassification,
XLMForQuestionAnswering,
XLMForQuestionAnsweringSimple,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_roberta import (
RobertaForMaskedLM,
RobertaModel,
RobertaForSequenceClassification,
RobertaForMultipleChoice,
RobertaForTokenClassification,
RobertaForQuestionAnswering,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_distilbert import (
DistilBertPreTrainedModel,
DistilBertForMaskedLM,
DistilBertModel,
DistilBertForSequenceClassification,
DistilBertForQuestionAnswering,
DistilBertForTokenClassification,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_camembert import (
CamembertForMaskedLM,
CamembertModel,
CamembertForSequenceClassification,
CamembertForMultipleChoice,
CamembertForTokenClassification,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
from .modeling_t5 import (T5PreTrainedModel, T5Model, T5WithLMHeadModel, from .modeling_t5 import (
load_tf_weights_in_t5, T5PreTrainedModel,
T5_PRETRAINED_MODEL_ARCHIVE_MAP) T5Model,
from .modeling_albert import (AlbertPreTrainedModel, AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, T5WithLMHeadModel,
AlbertForQuestionAnswering, load_tf_weights_in_t5,
load_tf_weights_in_albert, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP) T5_PRETRAINED_MODEL_ARCHIVE_MAP,
from .modeling_xlm_roberta import (XLMRobertaForMaskedLM, XLMRobertaModel, XLMRobertaForMultipleChoice, )
XLMRobertaForSequenceClassification, XLMRobertaForTokenClassification) from .modeling_albert import (
AlbertPreTrainedModel,
AlbertModel,
AlbertForMaskedLM,
AlbertForSequenceClassification,
AlbertForQuestionAnswering,
load_tf_weights_in_albert,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_xlm_roberta import (
XLMRobertaForMaskedLM,
XLMRobertaModel,
XLMRobertaForMultipleChoice,
XLMRobertaForSequenceClassification,
XLMRobertaForTokenClassification,
)
from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification
# Optimization # Optimization
from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, from .optimization import (
get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup) AdamW,
get_constant_schedule,
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_cosine_with_hard_restarts_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
# TensorFlow # TensorFlow
if is_tf_available(): if is_tf_available():
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, TFSequenceSummary, shape_list
from .modeling_tf_auto import (TFAutoModel, TFAutoModelForSequenceClassification, TFAutoModelForQuestionAnswering, from .modeling_tf_auto import (
TFAutoModelWithLMHead, TFAutoModelForTokenClassification, TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP) TFAutoModel,
TFAutoModelForSequenceClassification,
from .modeling_tf_bert import (TFBertPreTrainedModel, TFBertMainLayer, TFBertEmbeddings, TFAutoModelForQuestionAnswering,
TFBertModel, TFBertForPreTraining, TFAutoModelWithLMHead,
TFBertForMaskedLM, TFBertForNextSentencePrediction, TFAutoModelForTokenClassification,
TFBertForSequenceClassification, TFBertForMultipleChoice, TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
TFBertForTokenClassification, TFBertForQuestionAnswering, )
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_tf_bert import (
from .modeling_tf_gpt2 import (TFGPT2PreTrainedModel, TFGPT2MainLayer, TFBertPreTrainedModel,
TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel, TFBertMainLayer,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) TFBertEmbeddings,
TFBertModel,
from .modeling_tf_openai import (TFOpenAIGPTPreTrainedModel, TFOpenAIGPTMainLayer, TFBertForPreTraining,
TFOpenAIGPTModel, TFOpenAIGPTLMHeadModel, TFOpenAIGPTDoubleHeadsModel, TFBertForMaskedLM,
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) TFBertForNextSentencePrediction,
TFBertForSequenceClassification,
from .modeling_tf_transfo_xl import (TFTransfoXLPreTrainedModel, TFTransfoXLMainLayer, TFBertForMultipleChoice,
TFTransfoXLModel, TFTransfoXLLMHeadModel, TFBertForTokenClassification,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) TFBertForQuestionAnswering,
TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
from .modeling_tf_xlnet import (TFXLNetPreTrainedModel, TFXLNetMainLayer, )
TFXLNetModel, TFXLNetLMHeadModel,
TFXLNetForSequenceClassification, from .modeling_tf_gpt2 import (
TFXLNetForTokenClassification, TFGPT2PreTrainedModel,
TFXLNetForQuestionAnsweringSimple, TFGPT2MainLayer,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) TFGPT2Model,
TFGPT2LMHeadModel,
from .modeling_tf_xlm import (TFXLMPreTrainedModel, TFXLMMainLayer, TFGPT2DoubleHeadsModel,
TFXLMModel, TFXLMWithLMHeadModel, TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
TFXLMForSequenceClassification, )
TFXLMForQuestionAnsweringSimple,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_openai import (
TFOpenAIGPTPreTrainedModel,
from .modeling_tf_roberta import (TFRobertaPreTrainedModel, TFRobertaMainLayer, TFOpenAIGPTMainLayer,
TFRobertaModel, TFRobertaForMaskedLM, TFOpenAIGPTModel,
TFRobertaForSequenceClassification, TFOpenAIGPTLMHeadModel,
TFRobertaForTokenClassification, TFOpenAIGPTDoubleHeadsModel,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
TFDistilBertModel, TFDistilBertForMaskedLM, from .modeling_tf_transfo_xl import (
TFDistilBertForSequenceClassification, TFTransfoXLPreTrainedModel,
TFDistilBertForTokenClassification, TFTransfoXLMainLayer,
TFDistilBertForQuestionAnswering, TFTransfoXLModel,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
from .modeling_tf_ctrl import (TFCTRLPreTrainedModel, TFCTRLModel, )
TFCTRLLMHeadModel,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_tf_xlnet import (
TFXLNetPreTrainedModel,
from .modeling_tf_albert import (TFAlbertPreTrainedModel, TFAlbertModel, TFAlbertForMaskedLM, TFXLNetMainLayer,
TFAlbertForSequenceClassification, TFXLNetModel,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP) TFXLNetLMHeadModel,
TFXLNetForSequenceClassification,
from .modeling_tf_t5 import (TFT5PreTrainedModel, TFT5Model, TFT5WithLMHeadModel, TFXLNetForTokenClassification,
TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP) TFXLNetForQuestionAnsweringSimple,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_tf_xlm import (
TFXLMPreTrainedModel,
TFXLMMainLayer,
TFXLMModel,
TFXLMWithLMHeadModel,
TFXLMForSequenceClassification,
TFXLMForQuestionAnsweringSimple,
TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_tf_roberta import (
TFRobertaPreTrainedModel,
TFRobertaMainLayer,
TFRobertaModel,
TFRobertaForMaskedLM,
TFRobertaForSequenceClassification,
TFRobertaForTokenClassification,
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_tf_distilbert import (
TFDistilBertPreTrainedModel,
TFDistilBertMainLayer,
TFDistilBertModel,
TFDistilBertForMaskedLM,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertForQuestionAnswering,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_tf_ctrl import (
TFCTRLPreTrainedModel,
TFCTRLModel,
TFCTRLLMHeadModel,
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_tf_albert import (
TFAlbertPreTrainedModel,
TFAlbertModel,
TFAlbertForMaskedLM,
TFAlbertForSequenceClassification,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_tf_t5 import TFT5PreTrainedModel, TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
# Optimization # Optimization
from .optimization_tf import (WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator) from .optimization_tf import WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator
# TF 2.0 <=> PyTorch conversion utilities # TF 2.0 <=> PyTorch conversion utilities
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name, from .modeling_tf_pytorch_utils import (
load_pytorch_checkpoint_in_tf2_model, convert_tf_weight_name_to_pt_weight_name,
load_pytorch_weights_in_tf2_model, load_pytorch_checkpoint_in_tf2_model,
load_pytorch_model_in_tf2_model, load_pytorch_weights_in_tf2_model,
load_tf2_checkpoint_in_pytorch_model, load_pytorch_model_in_tf2_model,
load_tf2_weights_in_pytorch_model, load_tf2_checkpoint_in_pytorch_model,
load_tf2_model_in_pytorch_model) load_tf2_weights_in_pytorch_model,
load_tf2_model_in_pytorch_model,
)
# Pipelines # Pipelines
from .pipelines import pipeline, PipelineDataFormat, CsvPipelineDataFormat, JsonPipelineDataFormat, PipedPipelineDataFormat, \ from .pipelines import (
Pipeline, FeatureExtractionPipeline, QuestionAnsweringPipeline, NerPipeline, TextClassificationPipeline pipeline,
PipelineDataFormat,
CsvPipelineDataFormat,
JsonPipelineDataFormat,
PipedPipelineDataFormat,
Pipeline,
FeatureExtractionPipeline,
QuestionAnsweringPipeline,
NerPipeline,
TextClassificationPipeline,
)
if not is_tf_available() and not is_torch_available(): if not is_tf_available() and not is_torch_available():
logger.warning("Neither PyTorch nor TensorFlow >= 2.0 have been found." logger.warning(
"Models won't be available and only tokenizers, configuration" "Neither PyTorch nor TensorFlow >= 2.0 have been found."
"and file/data utilities can be used.") "Models won't be available and only tokenizers, configuration"
"and file/data utilities can be used."
)
# coding: utf8 # coding: utf8
def main(): def main():
import sys import sys
if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]: if len(sys.argv) < 2 or sys.argv[1] not in ["convert", "train", "predict", "serve"]:
print( print(
"First argument to `transformers` command line interface should be one of: \n" "First argument to `transformers` command line interface should be one of: \n"
">> convert serve train predict") ">> convert serve train predict"
)
if sys.argv[1] == "convert": if sys.argv[1] == "convert":
from transformers.commands import convert from transformers.commands import convert
convert(sys.argv) convert(sys.argv)
elif sys.argv[1] == "train": elif sys.argv[1] == "train":
from transformers.commands import train from transformers.commands import train
train(sys.argv) train(sys.argv)
elif sys.argv[1] == "serve": elif sys.argv[1] == "serve":
pass pass
...@@ -19,7 +24,6 @@ def main():
# parser = ArgumentParser('Transformers CLI tool', usage='transformers serve <command> [<args>]') # parser = ArgumentParser('Transformers CLI tool', usage='transformers serve <command> [<args>]')
# commands_parser = parser.add_subparsers(help='transformers-cli command helpers') # commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# # Register commands # # Register commands
# ServeCommand.register_subcommand(commands_parser) # ServeCommand.register_subcommand(commands_parser)
...@@ -33,5 +37,6 @@ def main():
# service = args.func(args) # service = args.func(args)
# service.run() # service.run()
if __name__ == '__main__':
if __name__ == "__main__":
main() main()
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from argparse import ArgumentParser from argparse import ArgumentParser
class BaseTransformersCLICommand(ABC): class BaseTransformersCLICommand(ABC):
@staticmethod @staticmethod
@abstractmethod @abstractmethod
......
...@@ -11,12 +11,12 @@ def convert_command_factory(args: Namespace):
Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint. Factory function used to convert a model TF 1.0 checkpoint in a PyTorch checkpoint.
:return: ServeCommand :return: ServeCommand
""" """
return ConvertCommand(args.model_type, args.tf_checkpoint, args.pytorch_dump_output, return ConvertCommand(
args.config, args.finetuning_task_name) args.model_type, args.tf_checkpoint, args.pytorch_dump_output, args.config, args.finetuning_task_name
)
class ConvertCommand(BaseTransformersCLICommand): class ConvertCommand(BaseTransformersCLICommand):
@staticmethod @staticmethod
def register_subcommand(parser: ArgumentParser): def register_subcommand(parser: ArgumentParser):
""" """
...@@ -24,25 +24,39 @@ class ConvertCommand(BaseTransformersCLICommand):
:param parser: Root parser to register command-specific arguments :param parser: Root parser to register command-specific arguments
:return: :return:
""" """
train_parser = parser.add_parser('convert', help="CLI tool to run convert model from original " train_parser = parser.add_parser(
"author checkpoints to Transformesr PyTorch checkpoints.") "convert",
train_parser.add_argument('--model_type', type=str, required=True, help="CLI tool to run convert model from original "
help='Model\'s type.') "author checkpoints to Transformesr PyTorch checkpoints.",
train_parser.add_argument('--tf_checkpoint', type=str, required=True, )
help='TensorFlow checkpoint path or folder.') train_parser.add_argument("--model_type", type=str, required=True, help="Model's type.")
train_parser.add_argument('--pytorch_dump_output', type=str, required=True, train_parser.add_argument(
help='Path to the PyTorch savd model output.') "--tf_checkpoint", type=str, required=True, help="TensorFlow checkpoint path or folder."
train_parser.add_argument('--config', type=str, default="", )
help='Configuration file path or folder.') train_parser.add_argument(
train_parser.add_argument('--finetuning_task_name', type=str, default=None, "--pytorch_dump_output", type=str, required=True, help="Path to the PyTorch savd model output."
help='Optional fine-tuning task name if the TF model was a finetuned model.') )
train_parser.add_argument("--config", type=str, default="", help="Configuration file path or folder.")
train_parser.add_argument(
"--finetuning_task_name",
type=str,
default=None,
help="Optional fine-tuning task name if the TF model was a finetuned model.",
)
train_parser.set_defaults(func=convert_command_factory) train_parser.set_defaults(func=convert_command_factory)
def __init__(self, model_type: str, tf_checkpoint: str, pytorch_dump_output: str, def __init__(
config: str, finetuning_task_name: str, *args): self,
self._logger = getLogger('transformers-cli/converting') model_type: str,
tf_checkpoint: str,
pytorch_dump_output: str,
config: str,
finetuning_task_name: str,
*args
):
self._logger = getLogger("transformers-cli/converting")
self._logger.info('Loading model {}'.format(model_type)) self._logger.info("Loading model {}".format(model_type))
self._model_type = model_type self._model_type = model_type
self._tf_checkpoint = tf_checkpoint self._tf_checkpoint = tf_checkpoint
self._pytorch_dump_output = pytorch_dump_output self._pytorch_dump_output = pytorch_dump_output
...@@ -52,63 +66,80 @@ class ConvertCommand(BaseTransformersCLICommand):
def run(self): def run(self):
if self._model_type == "bert": if self._model_type == "bert":
try: try:
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch,
)
except ImportError: except ImportError:
msg = "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, " \ msg = (
"In that case, it requires TensorFlow to be installed. Please see " \ "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions." "https://www.tensorflow.org/install/ for installation instructions."
)
raise ImportError(msg) raise ImportError(msg)
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "gpt": elif self._model_type == "gpt":
from transformers.convert_openai_original_tf_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch from transformers.convert_openai_original_tf_checkpoint_to_pytorch import (
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, convert_openai_checkpoint_to_pytorch,
self._config, )
self._pytorch_dump_output)
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "transfo_xl":
            try:
                from transformers.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
                    convert_transfo_xl_checkpoint_to_pytorch,
                )
            except ImportError:
                msg = (
                    "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
                    "In that case, it requires TensorFlow to be installed. Please see "
                    "https://www.tensorflow.org/install/ for installation instructions."
                )
                raise ImportError(msg)
            if "ckpt" in self._tf_checkpoint.lower():
                TF_CHECKPOINT = self._tf_checkpoint
                TF_DATASET_FILE = ""
            else:
                TF_DATASET_FILE = self._tf_checkpoint
                TF_CHECKPOINT = ""
            convert_transfo_xl_checkpoint_to_pytorch(
                TF_CHECKPOINT, self._config, self._pytorch_dump_output, TF_DATASET_FILE
            )
        elif self._model_type == "gpt2":
            try:
                from transformers.convert_gpt2_original_tf_checkpoint_to_pytorch import (
                    convert_gpt2_checkpoint_to_pytorch,
                )
            except ImportError:
                msg = (
                    "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
                    "In that case, it requires TensorFlow to be installed. Please see "
                    "https://www.tensorflow.org/install/ for installation instructions."
                )
                raise ImportError(msg)
            convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
        elif self._model_type == "xlnet":
            try:
                from transformers.convert_xlnet_original_tf_checkpoint_to_pytorch import (
                    convert_xlnet_checkpoint_to_pytorch,
                )
            except ImportError:
                msg = (
                    "transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
                    "In that case, it requires TensorFlow to be installed. Please see "
                    "https://www.tensorflow.org/install/ for installation instructions."
                )
                raise ImportError(msg)
            convert_xlnet_checkpoint_to_pytorch(
                self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
            )
        elif self._model_type == "xlm":
            from transformers.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
                convert_xlm_checkpoint_to_pytorch,
            )

            convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
        else:
...
@@ -8,13 +8,16 @@ def download_command_factory(args):

class DownloadCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        download_parser = parser.add_parser("download")
        download_parser.add_argument(
            "--cache-dir", type=str, default=None, help="Path to location to store the models"
        )
        download_parser.add_argument(
            "--force", action="store_true", help="Force the model to be download even if already in cache-dir"
        )
        download_parser.add_argument("model", type=str, help="Name of the model to download")
        download_parser.set_defaults(func=download_command_factory)

    def __init__(self, model: str, cache: str, force: bool):
@@ -26,4 +29,4 @@ class DownloadCommand(BaseTransformersCLICommand):
        from transformers import AutoModel, AutoTokenizer

        AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
        AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
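A minimal sketch of exercising the download command directly, assuming its entry point is run() as for the other commands in this diff; the model name is only an example:

# Hedged sketch, not part of the diff: constructor signature (model, cache, force)
# is taken from the class above; "bert-base-uncased" is a hypothetical example model.
cmd = DownloadCommand(model="bert-base-uncased", cache=None, force=False)
cmd.run()  # fetches the model weights and tokenizer files into the cache directory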
@@ -10,52 +10,72 @@ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

def try_infer_format_from_ext(path: str):
    if not path:
        return "pipe"

    for ext in PipelineDataFormat.SUPPORTED_FORMATS:
        if path.endswith(ext):
            return ext

    raise Exception(
        "Unable to determine file format from file extension {}. "
        "Please provide the format through --format {}".format(path, PipelineDataFormat.SUPPORTED_FORMATS)
    )


def run_command_factory(args):
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
    reader = PipelineDataFormat.from_str(
        format=format,
        output_path=args.output,
        input_path=args.input,
        column=args.column if args.column else nlp.default_input_names,
        overwrite=args.overwrite,
    )
    return RunCommand(nlp, reader)


class RunCommand(BaseTransformersCLICommand):
    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        self._nlp = nlp
        self._reader = reader

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
        run_parser.add_argument("--task", choices=SUPPORTED_TASKS.keys(), help="Task to run")
        run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
        run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
        run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
        run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
        run_parser.add_argument(
            "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
        )
        run_parser.add_argument(
            "--column",
            type=str,
            help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
        )
        run_parser.add_argument(
            "--format",
            type=str,
            default="infer",
            choices=PipelineDataFormat.SUPPORTED_FORMATS,
            help="Input format to read from",
        )
        run_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
        run_parser.set_defaults(func=run_command_factory)

    def run(self):
@@ -71,9 +91,6 @@ class RunCommand(BaseTransformersCLICommand):
        # Saving data
        if self._nlp.binary_output:
            binary_path = self._reader.save_binary(outputs)
            logger.warning("Current pipeline requires output to be in binary format, saving at {}".format(binary_path))
        else:
            self._reader.save(outputs)
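A hedged, self-contained sketch of how run_command_factory is meant to be driven; the task, file names and column below are example values, not part of this diff:

# Hedged sketch: building the Namespace that argparse would produce for the "run"
# subcommand and handing it to run_command_factory.
from argparse import Namespace

args = Namespace(
    task="sentiment-analysis",  # example value; must be a key of SUPPORTED_TASKS
    input="reviews.csv",        # hypothetical input file
    output="predictions.csv",   # hypothetical output file
    model=None,                 # fall back to the task's default model
    config=None,
    tokenizer=None,
    column="text",              # hypothetical column name in the csv
    format="infer",             # let try_infer_format_from_ext() pick the format
    device=-1,                  # CPU
    overwrite=True,
)
command = run_command_factory(args)
command.run()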
@@ -7,6 +7,7 @@ try:
    from uvicorn import run
    from fastapi import FastAPI, HTTPException, Body
    from pydantic import BaseModel

    _serve_dependancies_installed = True
except (ImportError, AttributeError):
    BaseModel = object
@@ -17,18 +18,21 @@ from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline

logger = logging.getLogger("transformers-cli/serving")


def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate serving server from provided command line arguments.
    :return: ServeCommand
    """
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    return ServeCommand(nlp, args.host, args.port)
@@ -36,6 +40,7 @@ class ServeModelInfoResult(BaseModel):
    """
    Expose model information
    """

    infos: dict
@@ -43,6 +48,7 @@ class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    tokens: List[str]
    tokens_ids: Optional[List[int]]
@@ -51,6 +57,7 @@ class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model
    """

    text: str
@@ -58,11 +65,11 @@ class ServeForwardResult(BaseModel):
    """
    Forward result model
    """

    output: Any


class ServeCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
@@ -70,14 +77,23 @@ class ServeCommand(BaseTransformersCLICommand):
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task", type=str, choices=SUPPORTED_TASKS.keys(), help="The task to run the pipeline on"
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)

    def __init__(self, pipeline: Pipeline, host: str, port: int):
@@ -87,18 +103,22 @@ class ServeCommand(BaseTransformersCLICommand):
        self._host = host
        self._port = port

        if not _serve_dependancies_installed:
            raise ImportError(
                "Using serve command requires FastAPI and unicorn. "
                "Please install transformers with [serving]: pip install transformers[serving]."
                "Or install FastAPI and unicorn separatly."
            )
        else:
            logger.info("Serving model over {}:{}".format(host, port))
            self._app = FastAPI()

            # Register routes
            self._app.add_api_route("/", self.model_info, response_model=ServeModelInfoResult, methods=["GET"])
            self._app.add_api_route("/tokenize", self.tokenize, response_model=ServeTokenizeResult, methods=["POST"])
            self._app.add_api_route(
                "/detokenize", self.detokenize, response_model=ServeDeTokenizeResult, methods=["POST"]
            )
            self._app.add_api_route("/forward", self.forward, response_model=ServeForwardResult, methods=["POST"])

    def run(self):
        run(self._app, host=self._host, port=self._port)
@@ -122,11 +142,14 @@ class ServeCommand(BaseTransformersCLICommand):
            return ServeTokenizeResult(tokens=tokens_txt)

        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided tokens ids to readable text:
        - **tokens_ids**: List of tokens ids
@@ -135,9 +158,9 @@ class ServeCommand(BaseTransformersCLICommand):
        """
        try:
            decoded_str = self._pipeline.tokenizer.decode(tokens_ids, skip_special_tokens, cleanup_tokenization_spaces)
            return ServeDeTokenizeResult(model="", text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})

    def forward(self, inputs: Union[str, dict, List[str], List[int], List[dict]] = Body(None, embed=True)):
        """
...
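A hedged sketch of calling the /detokenize route once the server is up; the embedded JSON field names follow the Body(..., embed=True) parameters shown above, while the host, port and token ids are example values:

# Hedged sketch, not part of the diff: hitting the detokenize endpoint with a plain
# HTTP client. Assumes the server runs on the default host/port and that `requests`
# is installed; the token ids are arbitrary example values.
import requests

response = requests.post(
    "http://localhost:8888/detokenize",
    json={
        "tokens_ids": [101, 7592, 2088, 102],
        "skip_special_tokens": True,
        "cleanup_tokenization_spaces": True,
    },
)
print(response.json())  # ServeDeTokenizeResult payload: {"model": "", "text": "..."}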
@@ -3,9 +3,12 @@ from argparse import ArgumentParser, Namespace
from logging import getLogger

from transformers.commands import BaseTransformersCLICommand
from transformers import (
    is_tf_available,
    is_torch_available,
    TextClassificationPipeline,
    SingleSentenceClassificationProcessor as Processor,
)

if not is_tf_available() and not is_torch_available():
    raise ImportError("At least one of PyTorch or TensorFlow 2.0+ should be installed to use CLI training")
@@ -14,6 +17,7 @@ if not is_tf_available() and not is_torch_available():
USE_XLA = False
USE_AMP = False


def train_command_factory(args: Namespace):
    """
    Factory function used to instantiate serving server from provided command line arguments.
@@ -23,7 +27,6 @@ def train_command_factory(args: Namespace):

class TrainCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
@@ -31,47 +34,54 @@ class TrainCommand(BaseTransformersCLICommand):
        :param parser: Root parser to register command-specific arguments
        :return:
        """
        train_parser = parser.add_parser("train", help="CLI tool to train a model on a task.")
        train_parser.add_argument(
            "--train_data",
            type=str,
            required=True,
            help="path to train (and optionally evaluation) dataset as a csv with "
            "tab separated labels and sentences.",
        )
        train_parser.add_argument(
            "--column_label", type=int, default=0, help="Column of the dataset csv file with example labels."
        )
        train_parser.add_argument(
            "--column_text", type=int, default=1, help="Column of the dataset csv file with example texts."
        )
        train_parser.add_argument(
            "--column_id", type=int, default=2, help="Column of the dataset csv file with example ids."
        )
        train_parser.add_argument(
            "--skip_first_row", action="store_true", help="Skip the first row of the csv file (headers)."
        )

        train_parser.add_argument("--validation_data", type=str, default="", help="path to validation dataset.")
        train_parser.add_argument(
            "--validation_split",
            type=float,
            default=0.1,
            help="if validation dataset is not provided, fraction of train dataset to use as validation dataset.",
        )

        train_parser.add_argument("--output", type=str, default="./", help="path to saved the trained model.")

        train_parser.add_argument(
            "--task", type=str, default="text_classification", help="Task to train the model on."
        )
        train_parser.add_argument(
            "--model", type=str, default="bert-base-uncased", help="Model's name or path to stored model."
        )
        train_parser.add_argument("--train_batch_size", type=int, default=32, help="Batch size for training.")
        train_parser.add_argument("--valid_batch_size", type=int, default=64, help="Batch size for validation.")
        train_parser.add_argument("--learning_rate", type=float, default=3e-5, help="Learning rate.")
        train_parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon for Adam optimizer.")
        train_parser.set_defaults(func=train_command_factory)

    def __init__(self, args: Namespace):
        self.logger = getLogger("transformers-cli/training")

        self.framework = "tf" if is_tf_available() else "torch"

        os.makedirs(args.output, exist_ok=True)
        assert os.path.isdir(args.output)
@@ -81,28 +91,32 @@ class TrainCommand(BaseTransformersCLICommand):
        self.column_text = args.column_text
        self.column_id = args.column_id

        self.logger.info("Loading {} pipeline for {}".format(args.task, args.model))
        if args.task == "text_classification":
            self.pipeline = TextClassificationPipeline.from_pretrained(args.model)
        elif args.task == "token_classification":
            raise NotImplementedError
        elif args.task == "question_answering":
            raise NotImplementedError

        self.logger.info("Loading dataset from {}".format(args.train_data))
        self.train_dataset = Processor.create_from_csv(
            args.train_data,
            column_label=args.column_label,
            column_text=args.column_text,
            column_id=args.column_id,
            skip_first_row=args.skip_first_row,
        )
        self.valid_dataset = None
        if args.validation_data:
            self.logger.info("Loading validation dataset from {}".format(args.validation_data))
            self.valid_dataset = Processor.create_from_csv(
                args.validation_data,
                column_label=args.column_label,
                column_text=args.column_text,
                column_id=args.column_id,
                skip_first_row=args.skip_first_row,
            )

        self.validation_split = args.validation_split
        self.train_batch_size = args.train_batch_size
@@ -111,7 +125,7 @@ class TrainCommand(BaseTransformersCLICommand):
        self.adam_epsilon = args.adam_epsilon

    def run(self):
        if self.framework == "tf":
            return self.run_tf()
        return self.run_torch()
@@ -119,13 +133,15 @@ class TrainCommand(BaseTransformersCLICommand):
        raise NotImplementedError

    def run_tf(self):
        self.pipeline.fit(
            self.train_dataset,
            validation_data=self.valid_dataset,
            validation_split=self.validation_split,
            learning_rate=self.learning_rate,
            adam_epsilon=self.adam_epsilon,
            train_batch_size=self.train_batch_size,
            valid_batch_size=self.valid_batch_size,
        )

        # Save trained pipeline
        self.pipeline.save_pretrained(self.output)
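A hedged sketch of driving TrainCommand programmatically with the same defaults the parser registers above; the data path is a hypothetical example and only the attributes visible in this diff are assumed:

# Hedged sketch, not part of the diff: the Namespace mirrors the "train" subcommand
# arguments registered above, with their documented defaults.
from argparse import Namespace

args = Namespace(
    train_data="train.csv",  # tab-separated labels and sentences (hypothetical file)
    column_label=0,
    column_text=1,
    column_id=2,
    skip_first_row=True,
    validation_data="",
    validation_split=0.1,
    output="./",
    task="text_classification",
    model="bert-base-uncased",
    train_batch_size=32,
    valid_batch_size=64,
    learning_rate=3e-5,
    adam_epsilon=1e-08,
)
TrainCommand(args).run()  # dispatches to run_tf() or run_torch() depending on the available backend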
@@ -9,28 +9,31 @@ from transformers.hf_api import HfApi, HfFolder, HTTPError

class UserCommands(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        login_parser = parser.add_parser("login")
        login_parser.set_defaults(func=lambda args: LoginCommand(args))
        whoami_parser = parser.add_parser("whoami")
        whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
        logout_parser = parser.add_parser("logout")
        logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
        list_parser = parser.add_parser("ls")
        list_parser.set_defaults(func=lambda args: ListObjsCommand(args))
        # upload
        upload_parser = parser.add_parser("upload")
        upload_parser.add_argument("path", type=str, help="Local path of the folder or individual file to upload.")
        upload_parser.add_argument(
            "--filename", type=str, default=None, help="Optional: override individual object filename on S3."
        )
        upload_parser.set_defaults(func=lambda args: UploadCommand(args))


class ANSI:
    """
    Helper for en.wikipedia.org/wiki/ANSI_escape_code
    """

    _bold = u"\u001b[1m"
    _reset = u"\u001b[0m"

    @classmethod
    def bold(cls, s):
        return "{}{}{}".format(cls._bold, s, cls._reset)
@@ -44,14 +47,16 @@ class BaseUserCommand:

class LoginCommand(BaseUserCommand):
    def run(self):
        print(
            """
        _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
        _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
        _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
        _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
        _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|
        """
        )
        username = input("Username: ")
        password = getpass()
        try:
@@ -101,16 +106,10 @@ class ListObjsCommand(BaseUserCommand):
        col_widths = [max(len(str(x)) for x in col) for col in zip(*rows, headers)]
        row_format = ("{{:{}}} " * len(headers)).format(*col_widths)
        lines = []
        lines.append(row_format.format(*headers))
        lines.append(row_format.format(*["-" * w for w in col_widths]))
        for row in rows:
            lines.append(row_format.format(*row))
        return "\n".join(lines)

    def run(self):
@@ -126,15 +125,8 @@ class ListObjsCommand(BaseUserCommand):
        if len(objs) == 0:
            print("No shared file yet")
            exit()
        rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
        print(self.tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))


class UploadCommand(BaseUserCommand):
@@ -143,13 +135,7 @@ class UploadCommand(BaseUserCommand):
        Recursively list all files in a folder.
        """
        entries: List[os.DirEntry] = list(os.scandir(rel_path))
        files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()]  # filepath  # filename
        for f in entries:
            if f.is_dir():
                files += self.walk_dir(f.path)
@@ -173,22 +159,14 @@ class UploadCommand(BaseUserCommand):
            raise ValueError("Not a valid file or directory: {}".format(local_path))

        for filepath, filename in files:
            print("About to upload file {} to S3 under filename {}".format(ANSI.bold(filepath), ANSI.bold(filename)))

        choice = input("Proceed? [Y/n] ").lower()
        if not (choice == "" or choice == "y" or choice == "yes"):
            print("Abort")
            exit()
        print(ANSI.bold("Uploading... This might take a while if files are large"))
        for filepath, filename in files:
            access_url = self._api.presign_and_upload(token=token, filename=filename, filepath=filepath)
            print("Your file now lives at:")
            print(access_url)
@@ -18,16 +18,17 @@
from .configuration_utils import PretrainedConfig

ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
    "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
    "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
    "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
    "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
    "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
    "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
    "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
}


class AlbertConfig(PretrainedConfig):
    """Configuration for `AlbertModel`.
@@ -36,22 +37,25 @@ class AlbertConfig(PretrainedConfig):

    pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        vocab_size=30000,
        embedding_size=128,
        hidden_size=4096,
        num_hidden_layers=12,
        num_hidden_groups=1,
        num_attention_heads=64,
        intermediate_size=16384,
        inner_group_num=1,
        hidden_act="gelu_new",
        hidden_dropout_prob=0,
        attention_probs_dropout_prob=0,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        **kwargs
    ):
        """Constructs AlbertConfig.

        Args:
...
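A hedged sketch of how these defaults get used; the overrides below are arbitrary example values, not values taken from this diff:

# Hedged sketch, not part of the diff: instantiating AlbertConfig from the module
# shown above, overriding a few of the listed defaults.
from transformers.configuration_albert import AlbertConfig

config = AlbertConfig(hidden_size=768, num_attention_heads=12, intermediate_size=3072)
print(config.vocab_size, config.hidden_act)  # untouched defaults: 30000, gelu_new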