Unverified commit 54abc67a, authored by Thomas Wolf, committed by GitHub

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices
parents 645713e2 c11b3e29
@@ -21,21 +21,27 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import argparse
 import logging
 
-import torch
 import numpy as np
+import torch
 
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
-from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
-from transformers import XLNetLMHeadModel, XLNetTokenizer
-from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
-from transformers import CTRLLMHeadModel, CTRLTokenizer
-from transformers import XLMWithLMHeadModel, XLMTokenizer
+from transformers import (
+    CTRLLMHeadModel,
+    CTRLTokenizer,
+    GPT2LMHeadModel,
+    GPT2Tokenizer,
+    OpenAIGPTLMHeadModel,
+    OpenAIGPTTokenizer,
+    TransfoXLLMHeadModel,
+    TransfoXLTokenizer,
+    XLMTokenizer,
+    XLMWithLMHeadModel,
+    XLNetLMHeadModel,
+    XLNetTokenizer,
+)
 logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-    datefmt="%m/%d/%Y %H:%M:%S",
-    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,
 )
 logger = logging.getLogger(__name__)
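The changes above are mechanical: the imports are regrouped and alphabetized the way isort orders them, and the multi-line logging.basicConfig call is packed back toward the line-length budget the way black reflows code. A hedged sketch of the commands that produce this style; the 119-character line length and the target directory are assumptions, not stated in this diff, and older isort versions need a --recursive flag to walk directories:

    pip install black isort
    black --line-length 119 examples
    isort examples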
@@ -71,6 +77,7 @@ def set_seed(args):
     if args.n_gpu > 0:
         torch.cuda.manual_seed_all(args.seed)
 
+
 #
 # Functions to prepare models' input
 #
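This hunk shows only the tail of set_seed. For orientation, here is a minimal sketch of what a seeding helper with this tail typically looks like; everything except the CUDA branch is an assumption, since the diff does not show the rest of the function:

    import random

    import numpy as np
    import torch


    def set_seed(args):
        # seed every RNG the script touches so runs are reproducible
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.n_gpu > 0:
            # the only lines visible in the hunk above
            torch.cuda.manual_seed_all(args.seed)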
@@ -78,15 +85,11 @@ def set_seed(args):
 def prepare_ctrl_input(args, _, tokenizer, prompt_text):
     if args.temperature > 0.7:
-        logger.info(
-            "CTRL typically works better with lower temperatures (and lower top_k)."
-        )
+        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
 
     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
     if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
-        logger.info(
-            "WARNING! You are not starting your generation from a control code so you won't get good results"
-        )
+        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
 
     return prompt_text
@@ -102,11 +105,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
     else:
         language = None
         while language not in available_languages:
-            language = input(
-                "Using XLM. Select language in "
-                + str(list(available_languages))
-                + " >>> "
-            )
+            language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
 
     # kwargs["language"] = tokenizer.lang2id[language]
     # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
@@ -148,17 +147,34 @@ def adjust_length_to_model(length, max_sequence_length):
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_type", default=None, type=str, required=True,
-                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
-    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
-                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument(
+        "--model_type",
+        default=None,
+        type=str,
+        required=True,
+        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+    )
+    parser.add_argument(
+        "--model_name_or_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+    )
     parser.add_argument("--prompt", type=str, default="")
     parser.add_argument("--length", type=int, default=20)
     parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
-    parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 1.0 has no effect, lower tend toward greedy sampling")
-    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.0,
+        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
+    )
+    parser.add_argument(
+        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
+    )
     parser.add_argument("--k", type=int, default=0)
     parser.add_argument("--p", type=float, default=0.9)
@@ -169,9 +185,7 @@ def main():
     parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
     args = parser.parse_args()
 
-    args.device = torch.device(
-        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
-    )
+    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
     args.n_gpu = torch.cuda.device_count()
 
     set_seed(args)
@@ -181,17 +195,13 @@ def main():
     try:
         args.model_type = args.model_type.lower()
         model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
     except KeyError:
-        raise KeyError(
-            "the model {} you specified is not supported. You are welcome to add it and open a PR :)"
-        )
+        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")
 
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
     model = model_class.from_pretrained(args.model_name_or_path)
     model.to(args.device)
 
-    args.length = adjust_length_to_model(
-        args.length, max_sequence_length=model.config.max_position_embeddings
-    )
+    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
     logger.info(args)
 
     prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
@@ -201,7 +211,7 @@ def main():
     if requires_preprocessing:
         prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
         prompt_text = prepare_input(args, model, tokenizer, prompt_text)
-    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
+    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
 
     output_sequences = model.generate(
         input_ids=encoded_prompt,
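The model.generate call is cut off by the next hunk; only input_ids is visible. A hedged sketch of how the sampling flags defined earlier would plausibly be threaded through (every keyword after input_ids is an assumption on my part; the names match the transformers v2.x generate signature):

    output_sequences = model.generate(
        input_ids=encoded_prompt,
        max_length=args.length,                      # assumed: wired to --length
        temperature=args.temperature,                # assumed: wired to --temperature
        top_k=args.k,                                # assumed: wired to --k
        top_p=args.p,                                # assumed: wired to --p
        repetition_penalty=args.repetition_penalty,  # assumed: wired to --repetition_penalty
        do_sample=True,
    )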
@@ -215,7 +225,7 @@ def main():
     # Batch size == 1. to add more examples please use num_return_sequences > 1
     generated_sequence = output_sequences[0].tolist()
     text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-    text = text[: t.find(args.stop_token) if args.stop_token else None]
+    text = text[: text.find(args.stop_token) if args.stop_token else None]
 
     print(text)
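The final hunk is the one substantive bug fix in this file: the old slice referenced t, a name that does not exist at that point, so any run that passed --stop_token died with a NameError. The corrected line works because a slice whose end is None keeps the whole string. A small self-contained illustration (the example strings are made up):

    text = "Hello world <eos> trailing tokens"

    # with a stop token: cut at its first occurrence
    stop_token = "<eos>"
    print(text[: text.find(stop_token) if stop_token else None])  # 'Hello world '

    # without a stop token: the conditional yields None and the slice keeps everything
    stop_token = None
    print(text[: text.find(stop_token) if stop_token else None])  # full string

    # note: if the stop token is set but absent, find() returns -1 and the
    # slice silently drops the last character, a quirk the fixed line inherits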