Commit a8e3336a authored by Julien Chaumond

[examples] Use AutoModels in more examples

parent ec6766a3
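The change is the same across every file below: the hand-maintained MODEL_CLASSES dict mapping each model type to its concrete (config, model, tokenizer) classes is dropped, and the scripts dispatch through the Auto* classes instead, which pick the right concrete class from the checkpoint's configuration. A minimal sketch of the resulting pattern (the checkpoint name is only an example; any supported model works, which is the point):

```python
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

# "bert-base-cased" is just an example checkpoint; swapping in "roberta-base"
# or any other supported model requires no other code change.
model_name = "bert-base-cased"

# AutoConfig reads the checkpoint's config and instantiates the matching
# config class; extra kwargs such as num_labels are forwarded to it.
config = AutoConfig.from_pretrained(model_name, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
```

The set of supported model types is then derived from the task-specific MODEL_FOR_*_MAPPING constants rather than hard-coded per script.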
@@ -31,6 +31,7 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 from transformers import (
+    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
     WEIGHTS_NAME,
     AdamW,
     AutoConfig,
@@ -38,7 +39,6 @@ from transformers import (
     AutoTokenizer,
     get_linear_schedule_with_warmup,
 )
-from transformers.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
 from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
@@ -52,6 +52,7 @@ logger = logging.getLogger(__name__)
 MODEL_CONFIG_CLASSES = list(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), ())
+TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]
......
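The TOKENIZER_ARGS list added above supports a filtering idiom: only the tokenizer-related options that were actually set on the command line get forwarded to AutoTokenizer.from_pretrained, so unset options fall back to the tokenizer's own defaults. A sketch under that assumption (the args dict stands in for the script's parsed flags; the checkpoint name is an example):

```python
from transformers import AutoTokenizer

TOKENIZER_ARGS = ["do_lower_case", "strip_accents", "keep_accents", "use_fast"]

# Stand-in for the parsed command-line arguments.
args = {"do_lower_case": True, "strip_accents": None, "keep_accents": None, "use_fast": False}

# Forward only the tokenizer options that were explicitly set (not None).
tokenizer_args = {k: v for k, v in args.items() if v is not None and k in TOKENIZER_ARGS}
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", **tokenizer_args)
```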
@@ -13,16 +13,11 @@ from seqeval import metrics
 from transformers import (
     TF2_WEIGHTS_NAME,
-    BertConfig,
-    BertTokenizer,
-    DistilBertConfig,
-    DistilBertTokenizer,
+    TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
+    AutoConfig,
+    AutoTokenizer,
     GradientAccumulator,
-    RobertaConfig,
-    RobertaTokenizer,
-    TFBertForTokenClassification,
-    TFDistilBertForTokenClassification,
-    TFRobertaForTokenClassification,
+    TFAutoModelForTokenClassification,
     create_optimizer,
 )
 from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
@@ -34,22 +29,17 @@ except ImportError:
 from fastprogress.fastprogress import master_bar, progress_bar

-ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()
-)
+MODEL_CONFIG_CLASSES = list(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

-MODEL_CLASSES = {
-    "bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
-    "roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer),
-}
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)

 flags.DEFINE_string(
     "data_dir", None, "The input data dir. Should contain the .conll files (or other data files) " "for the task."
 )

-flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+flags.DEFINE_string("model_type", None, "Model type selected in the list: " + ", ".join(MODEL_TYPES))
@@ -509,8 +499,7 @@ def main(_):
     labels = get_labels(args["labels"])
     num_labels = len(labels) + 1
     pad_token_label_id = 0
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args["model_type"]]
-    config = config_class.from_pretrained(
+    config = AutoConfig.from_pretrained(
         args["config_name"] if args["config_name"] else args["model_name_or_path"],
         num_labels=num_labels,
         cache_dir=args["cache_dir"] if args["cache_dir"] else None,
@@ -520,14 +509,14 @@
     # Training
     if args["do_train"]:
-        tokenizer = tokenizer_class.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             args["tokenizer_name"] if args["tokenizer_name"] else args["model_name_or_path"],
             do_lower_case=args["do_lower_case"],
             cache_dir=args["cache_dir"] if args["cache_dir"] else None,
         )

         with strategy.scope():
-            model = model_class.from_pretrained(
+            model = TFAutoModelForTokenClassification.from_pretrained(
                 args["model_name_or_path"],
                 from_pt=bool(".bin" in args["model_name_or_path"]),
                 config=config,
@@ -562,7 +551,7 @@
     # Evaluation
     if args["do_eval"]:
-        tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
+        tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
         checkpoints = []
         results = []
@@ -584,7 +573,7 @@
             global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"

             with strategy.scope():
-                model = model_class.from_pretrained(checkpoint)
+                model = TFAutoModelForTokenClassification.from_pretrained(checkpoint)

                 y_true, y_pred, eval_loss = evaluate(
                     args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev"
@@ -611,8 +600,8 @@
                 writer.write("\n")

     if args["do_predict"]:
-        tokenizer = tokenizer_class.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
-        model = model_class.from_pretrained(args["output_dir"])
+        tokenizer = AutoTokenizer.from_pretrained(args["output_dir"], do_lower_case=args["do_lower_case"])
+        model = TFAutoModelForTokenClassification.from_pretrained(args["output_dir"])
         eval_batch_size = args["per_device_eval_batch_size"] * args["n_device"]
         predict_dataset, _ = load_and_cache_examples(
             args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test"
......
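The TensorFlow token-classification script above now mirrors its PyTorch counterpart. A minimal sketch of the new loading path (requires TensorFlow; checkpoint name and label count are examples); note from_pt, which the script enables when the path contains ".bin", converts PyTorch weights on the fly:

```python
from transformers import AutoConfig, AutoTokenizer, TFAutoModelForTokenClassification

model_name = "bert-base-cased"  # example checkpoint
num_labels = 9                  # e.g. CoNLL-2003 entity tags

config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Pass from_pt=True here to load a PyTorch ".bin" checkpoint into the TF
# model, mirroring the script's `from_pt=bool(".bin" in ...)` test.
model = TFAutoModelForTokenClassification.from_pretrained(model_name, config=config)
```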
@@ -30,32 +30,12 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 from transformers import (
+    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
     WEIGHTS_NAME,
     AdamW,
-    AlbertConfig,
-    AlbertForSequenceClassification,
-    AlbertTokenizer,
-    BertConfig,
-    BertForSequenceClassification,
-    BertTokenizer,
-    DistilBertConfig,
-    DistilBertForSequenceClassification,
-    DistilBertTokenizer,
-    FlaubertConfig,
-    FlaubertForSequenceClassification,
-    FlaubertTokenizer,
-    RobertaConfig,
-    RobertaForSequenceClassification,
-    RobertaTokenizer,
-    XLMConfig,
-    XLMForSequenceClassification,
-    XLMRobertaConfig,
-    XLMRobertaForSequenceClassification,
-    XLMRobertaTokenizer,
-    XLMTokenizer,
-    XLNetConfig,
-    XLNetForSequenceClassification,
-    XLNetTokenizer,
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
     get_linear_schedule_with_warmup,
 )
 from transformers import glue_compute_metrics as compute_metrics
@@ -72,33 +52,10 @@ except ImportError:
 logger = logging.getLogger(__name__)

-ALL_MODELS = sum(
-    (
-        tuple(conf.pretrained_config_archive_map.keys())
-        for conf in (
-            BertConfig,
-            XLNetConfig,
-            XLMConfig,
-            RobertaConfig,
-            DistilBertConfig,
-            AlbertConfig,
-            XLMRobertaConfig,
-            FlaubertConfig,
-        )
-    ),
-    (),
-)
+MODEL_CONFIG_CLASSES = list(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

-MODEL_CLASSES = {
-    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
-    "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
-    "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
-    "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
-    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
-    "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
-    "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
-}
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
def set_seed(args):
@@ -442,7 +399,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+        help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
     )
     parser.add_argument(
         "--model_name_or_path",
@@ -622,19 +579,18 @@ def main():
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

     args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
+    config = AutoConfig.from_pretrained(
         args.config_name if args.config_name else args.model_name_or_path,
         num_labels=num_labels,
         finetuning_task=args.task_name,
         cache_dir=args.cache_dir if args.cache_dir else None,
     )
-    tokenizer = tokenizer_class.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         do_lower_case=args.do_lower_case,
         cache_dir=args.cache_dir if args.cache_dir else None,
     )
-    model = model_class.from_pretrained(
+    model = AutoModelForSequenceClassification.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
@@ -673,14 +629,14 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+        model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
         model.to(args.device)

     # Evaluation
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(
@@ -692,7 +648,7 @@ def main():
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
             prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

-            model = model_class.from_pretrained(checkpoint)
+            model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
             model.to(args.device)
             result = evaluate(args, model, tokenizer, prefix=prefix)
             result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
......
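The sequence-classification hunks above lean on the save/reload round-trip staying model-agnostic: save_pretrained writes the config, weights, and vocab files, which is everything from_pretrained needs to reload without consulting a MODEL_CLASSES table. A sketch (checkpoint name and output directory are examples):

```python
import os

from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "distilbert-base-uncased"  # example checkpoint
output_dir = "./glue_output"            # hypothetical output directory
os.makedirs(output_dir, exist_ok=True)  # save_pretrained expects an existing dir

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Writes config.json plus the weights and vocab files.
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# Reload without knowing the underlying model type.
model = AutoModelForSequenceClassification.from_pretrained(output_dir)
tokenizer = AutoTokenizer.from_pretrained(output_dir)
```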
@@ -38,28 +38,15 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 from transformers import (
+    CONFIG_MAPPING,
+    MODEL_WITH_LM_HEAD_MAPPING,
     WEIGHTS_NAME,
     AdamW,
-    BertConfig,
-    BertForMaskedLM,
-    BertTokenizer,
-    CamembertConfig,
-    CamembertForMaskedLM,
-    CamembertTokenizer,
-    DistilBertConfig,
-    DistilBertForMaskedLM,
-    DistilBertTokenizer,
-    GPT2Config,
-    GPT2LMHeadModel,
-    GPT2Tokenizer,
-    OpenAIGPTConfig,
-    OpenAIGPTLMHeadModel,
-    OpenAIGPTTokenizer,
+    AutoConfig,
+    AutoModelWithLMHead,
+    AutoTokenizer,
     PreTrainedModel,
     PreTrainedTokenizer,
-    RobertaConfig,
-    RobertaForMaskedLM,
-    RobertaTokenizer,
     get_linear_schedule_with_warmup,
 )
@@ -73,14 +60,8 @@ except ImportError:
 logger = logging.getLogger(__name__)

-MODEL_CLASSES = {
-    "gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
-    "openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
-    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
-    "camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
-}
+MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
class TextDataset(Dataset):
@@ -693,23 +674,21 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training download model & vocab

-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
     if args.config_name:
-        config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
+        config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)
     elif args.model_name_or_path:
-        config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+        config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
     else:
-        config = config_class()
+        config = CONFIG_MAPPING[args.model_type]()

     if args.tokenizer_name:
-        tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
     elif args.model_name_or_path:
-        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
     else:
         raise ValueError(
             "You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
-            "and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
+            "and load it from here, using --tokenizer_name".format(AutoTokenizer.__name__)
         )

     if args.block_size <= 0:
@@ -719,7 +698,7 @@
         args.block_size = min(args.block_size, tokenizer.max_len)

     if args.model_name_or_path:
-        model = model_class.from_pretrained(
+        model = AutoModelWithLMHead.from_pretrained(
             args.model_name_or_path,
             from_tf=bool(".ckpt" in args.model_name_or_path),
             config=config,
@@ -727,7 +706,7 @@
         )
     else:
         logger.info("Training new model from scratch")
-        model = model_class(config=config)
+        model = AutoModelWithLMHead.from_config(config)

     model.to(args.device)
@@ -768,8 +747,8 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+        model = AutoModelWithLMHead.from_pretrained(args.output_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
         model.to(args.device)

     # Evaluation
@@ -786,7 +765,7 @@ def main():
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
             prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

-            model = model_class.from_pretrained(checkpoint)
+            model = AutoModelWithLMHead.from_pretrained(checkpoint)
             model.to(args.device)
             result = evaluate(args, model, tokenizer, prefix=prefix)
             result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
......
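One subtlety in the language-modeling hunks above: Auto* classes cannot be instantiated directly (their constructors raise), so training from scratch goes through CONFIG_MAPPING to build a config for the requested model type and then AutoModelWithLMHead.from_config. A minimal sketch:

```python
from transformers import CONFIG_MAPPING, AutoModelWithLMHead

# CONFIG_MAPPING maps a model_type string to its config class, replacing the
# per-model config_class lookup the script used to do.
config = CONFIG_MAPPING["gpt2"]()  # default GPT-2-sized hyperparameters

# from_config builds a randomly initialized model of the right concrete
# class; calling AutoModelWithLMHead(config=config) would raise instead.
model = AutoModelWithLMHead.from_config(config)
```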
@@ -30,29 +30,12 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 from transformers import (
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     WEIGHTS_NAME,
     AdamW,
-    AlbertConfig,
-    AlbertForQuestionAnswering,
-    AlbertTokenizer,
-    BertConfig,
-    BertForQuestionAnswering,
-    BertTokenizer,
-    CamembertConfig,
-    CamembertForQuestionAnswering,
-    CamembertTokenizer,
-    DistilBertConfig,
-    DistilBertForQuestionAnswering,
-    DistilBertTokenizer,
-    RobertaConfig,
-    RobertaForQuestionAnswering,
-    RobertaTokenizer,
-    XLMConfig,
-    XLMForQuestionAnswering,
-    XLMTokenizer,
-    XLNetConfig,
-    XLNetForQuestionAnswering,
-    XLNetTokenizer,
+    AutoConfig,
+    AutoModelForQuestionAnswering,
+    AutoTokenizer,
     get_linear_schedule_with_warmup,
     squad_convert_examples_to_features,
 )
@@ -72,23 +55,10 @@ except ImportError:
 logger = logging.getLogger(__name__)

-ALL_MODELS = sum(
-    (
-        tuple(conf.pretrained_config_archive_map.keys())
-        for conf in (BertConfig, CamembertConfig, RobertaConfig, XLNetConfig, XLMConfig)
-    ),
-    (),
-)
+MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

-MODEL_CLASSES = {
-    "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
-    "camembert": (CamembertConfig, CamembertForQuestionAnswering, CamembertTokenizer),
-    "roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
-    "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
-    "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
-    "albert": (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
-}
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
def set_seed(args):
@@ -513,7 +483,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+        help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
     )
     parser.add_argument(
         "--model_name_or_path",
@@ -757,17 +727,16 @@ def main():
         torch.distributed.barrier()

     args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
+    config = AutoConfig.from_pretrained(
         args.config_name if args.config_name else args.model_name_or_path,
         cache_dir=args.cache_dir if args.cache_dir else None,
     )
-    tokenizer = tokenizer_class.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         do_lower_case=args.do_lower_case,
         cache_dir=args.cache_dir if args.cache_dir else None,
     )
-    model = model_class.from_pretrained(
+    model = AutoModelForQuestionAnswering.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
@@ -817,8 +786,8 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir)  # , force_download=True)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir)  # , force_download=True)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)

     # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
@@ -842,7 +811,7 @@ def main():
         for checkpoint in checkpoints:
             # Reload the model
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint)  # , force_download=True)
+            model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)  # , force_download=True)
             model.to(args.device)

             # Evaluate
......
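The question-answering hunks above reload each checkpoint through the Auto class inside the evaluation loop. A sketch of the checkpoint discovery that loop relies on, assuming a hypothetical output directory (the evaluate call itself is elided):

```python
import glob
import os

from transformers import WEIGHTS_NAME, AutoModelForQuestionAnswering

output_dir = "./squad_output"  # hypothetical output directory

# Every sub-directory under output_dir that contains a weights file
# (pytorch_model.bin) counts as a checkpoint to evaluate.
checkpoints = [
    os.path.dirname(c)
    for c in sorted(glob.glob(output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
]
for checkpoint in checkpoints:
    global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
    model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
    # ... evaluate(args, model, tokenizer, prefix=global_step) ...
```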
@@ -158,6 +158,12 @@ if is_torch_available():
         AutoModelWithLMHead,
         AutoModelForTokenClassification,
         ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        MODEL_MAPPING,
+        MODEL_FOR_PRETRAINING_MAPPING,
+        MODEL_WITH_LM_HEAD_MAPPING,
+        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
     )
     from .modeling_bert import (
@@ -317,6 +323,12 @@ if is_tf_available():
         TFAutoModelWithLMHead,
         TFAutoModelForTokenClassification,
         TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_MODEL_MAPPING,
+        TF_MODEL_FOR_PRETRAINING_MAPPING,
+        TF_MODEL_WITH_LM_HEAD_MAPPING,
+        TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
+        TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
+        TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
     )
     from .modeling_tf_bert import (
......
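With the MODEL_FOR_*_MAPPING constants exported from the package root above, any script can enumerate the model types a task supports instead of hard-coding them, which is exactly what the example scripts in this commit do:

```python
from transformers import MODEL_FOR_QUESTION_ANSWERING_MAPPING

# Each mapping is an ordered dict from config class to model class, so the
# supported model types fall out of the config classes' model_type attribute.
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
print(MODEL_TYPES)
```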
@@ -28,20 +28,12 @@ from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 from transformers import (
+    MODEL_FOR_QUESTION_ANSWERING_MAPPING,
     WEIGHTS_NAME,
     AdamW,
-    BertConfig,
-    BertForQuestionAnswering,
-    BertTokenizer,
-    DistilBertConfig,
-    DistilBertForQuestionAnswering,
-    DistilBertTokenizer,
-    XLMConfig,
-    XLMForQuestionAnswering,
-    XLMTokenizer,
-    XLNetConfig,
-    XLNetForQuestionAnswering,
-    XLNetTokenizer,
+    AutoConfig,
+    AutoModelForQuestionAnswering,
+    AutoTokenizer,
     get_linear_schedule_with_warmup,
 )
 from utils_squad import (
@@ -68,16 +60,10 @@ except ImportError:
 logger = logging.getLogger(__name__)

-ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
-)
+MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
+MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

-MODEL_CLASSES = {
-    "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
-    "xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
-    "xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
-}
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
def set_seed(args):
@@ -418,7 +404,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+        help="Model type selected in the list: " + ", ".join(MODEL_TYPES),
     )
     parser.add_argument(
         "--model_name_or_path",
@@ -626,17 +612,16 @@ def main():
         # download model & vocab

     args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
+    config = AutoConfig.from_pretrained(
         args.config_name if args.config_name else args.model_name_or_path,
         cache_dir=args.cache_dir if args.cache_dir else None,
     )
-    tokenizer = tokenizer_class.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         do_lower_case=args.do_lower_case,
         cache_dir=args.cache_dir if args.cache_dir else None,
     )
-    model = model_class.from_pretrained(
+    model = AutoModelForQuestionAnswering.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
@@ -687,8 +672,8 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        model = AutoModelForQuestionAnswering.from_pretrained(args.output_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         model.to(args.device)

     # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
@@ -706,7 +691,7 @@ def main():
         for checkpoint in checkpoints:
             # Reload the model
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint)
+            model = AutoModelForQuestionAnswering.from_pretrained(checkpoint)
             model.to(args.device)

             # Evaluate
......
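The from_tf flag seen throughout these hunks covers the one concern the Auto* classes do not: loading original TF 1.x checkpoints into PyTorch models. A sketch of the test the scripts use (the checkpoint name is an example; a path containing ".ckpt" would flip from_tf to True and trigger on-the-fly weight conversion):

```python
from transformers import AutoConfig, AutoModelForQuestionAnswering

model_name_or_path = "bert-base-uncased"  # example; could also be a ".ckpt" path

config = AutoConfig.from_pretrained(model_name_or_path)
# Same heuristic as the scripts: treat a ".ckpt" path as a TF 1.x checkpoint.
model = AutoModelForQuestionAnswering.from_pretrained(
    model_name_or_path,
    from_tf=bool(".ckpt" in model_name_or_path),
    config=config,
)
```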