Unverified Commit 54abc67a authored by Thomas Wolf, committed by GitHub

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices
parents 645713e2 c11b3e29
@@ -21,21 +21,27 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import argparse
 import logging

-import torch
 import numpy as np
+import torch

-from transformers import GPT2LMHeadModel, GPT2Tokenizer
-from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
-from transformers import XLNetLMHeadModel, XLNetTokenizer
-from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
-from transformers import CTRLLMHeadModel, CTRLTokenizer
-from transformers import XLMWithLMHeadModel, XLMTokenizer
+from transformers import (
+    CTRLLMHeadModel,
+    CTRLTokenizer,
+    GPT2LMHeadModel,
+    GPT2Tokenizer,
+    OpenAIGPTLMHeadModel,
+    OpenAIGPTTokenizer,
+    TransfoXLLMHeadModel,
+    TransfoXLTokenizer,
+    XLMTokenizer,
+    XLMWithLMHeadModel,
+    XLNetLMHeadModel,
+    XLNetTokenizer,
+)

 logging.basicConfig(
-    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-    datefmt="%m/%d/%Y %H:%M:%S",
-    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO,
 )
 logger = logging.getLogger(__name__)
@@ -71,6 +77,7 @@ def set_seed(args):
     if args.n_gpu > 0:
         torch.cuda.manual_seed_all(args.seed)

+
 #
 # Functions to prepare models' input
 #
@@ -78,15 +85,11 @@ def set_seed(args):
 def prepare_ctrl_input(args, _, tokenizer, prompt_text):
     if args.temperature > 0.7:
-        logger.info(
-            "CTRL typically works better with lower temperatures (and lower top_k)."
-        )
+        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
     if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
-        logger.info(
-            "WARNING! You are not starting your generation from a control code so you won't get good results"
-        )
+        logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
     return prompt_text
@@ -102,11 +105,7 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
         else:
             language = None
             while language not in available_languages:
-                language = input(
-                    "Using XLM. Select language in "
-                    + str(list(available_languages))
-                    + " >>> "
-                )
+                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")

         # kwargs["language"] = tokenizer.lang2id[language]

     # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
@@ -148,17 +147,34 @@ def adjust_length_to_model(length, max_sequence_length):
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_type", default=None, type=str, required=True,
-                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
-    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
-                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument(
+        "--model_type",
+        default=None,
+        type=str,
+        required=True,
+        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+    )
+    parser.add_argument(
+        "--model_name_or_path",
+        default=None,
+        type=str,
+        required=True,
+        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
+    )
     parser.add_argument("--prompt", type=str, default="")
     parser.add_argument("--length", type=int, default=20)
     parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
-    parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 1.0 has no effect, lower tend toward greedy sampling")
-    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.0,
+        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
+    )
+    parser.add_argument(
+        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
+    )
     parser.add_argument("--k", type=int, default=0)
     parser.add_argument("--p", type=float, default=0.9)
@@ -169,9 +185,7 @@ def main():
     parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
     args = parser.parse_args()

-    args.device = torch.device(
-        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
-    )
+    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
     args.n_gpu = torch.cuda.device_count()

     set_seed(args)
@@ -181,17 +195,13 @@ def main():
         args.model_type = args.model_type.lower()
         model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
     except KeyError:
-        raise KeyError(
-            "the model {} you specified is not supported. You are welcome to add it and open a PR :)"
-        )
+        raise KeyError("the model {} you specified is not supported. You are welcome to add it and open a PR :)")

     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
     model = model_class.from_pretrained(args.model_name_or_path)
     model.to(args.device)

-    args.length = adjust_length_to_model(
-        args.length, max_sequence_length=model.config.max_position_embeddings
-    )
+    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
     logger.info(args)

     prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
@@ -201,7 +211,7 @@ def main():
     if requires_preprocessing:
         prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
         prompt_text = prepare_input(args, model, tokenizer, prompt_text)
-    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
+    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")

     output_sequences = model.generate(
         input_ids=encoded_prompt,
@@ -212,10 +222,10 @@ def main():
         repetition_penalty=args.repetition_penalty,
     )

     # Batch size == 1. to add more examples please use num_return_sequences > 1
     generated_sequence = output_sequences[0].tolist()
     text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
-    text = text[: t.find(args.stop_token) if args.stop_token else None]
+    text = text[: text.find(args.stop_token) if args.stop_token else None]

     print(text)
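For orientation, the generation script reformatted above boils down to encode prompt, generate, decode, then truncate at an optional stop token. Below is a minimal, self-contained sketch of that flow, assuming a GPT-2 checkpoint; the fixed parameter values stand in for the script's command-line arguments and are illustrative, not part of this commit.

# Minimal sketch of the flow in the generation script above (assumption:
# a "gpt2" checkpoint; values mirror the script's defaults, not the commit).
import torch

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt_text = "In a shocking finding, scientists discovered"
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")

output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=40,           # plays the role of --length after adjust_length_to_model
    temperature=1.0,         # --temperature
    top_k=0,                 # --k
    top_p=0.9,               # --p
    repetition_penalty=1.0,  # --repetition_penalty
    do_sample=True,
)

# Batch size == 1, as in the script; decode and truncate at an optional stop token.
stop_token = None
text = tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True)
text = text[: text.find(stop_token) if stop_token else None]
print(text)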
@@ -14,9 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ BertAbs configuration """
-import json
 import logging
-import sys

 from transformers import PretrainedConfig
-from collections import deque
 import os
+from collections import deque

 import torch
 from torch.utils.data import Dataset
@@ -68,9 +68,7 @@ def process_story(raw_story):
     Raises:
         IndexError: If the stoy is empty or contains no highlights.
     """
-    nonempty_lines = list(
-        filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")])
-    )
+    nonempty_lines = list(filter(lambda x: len(x) != 0, [line.strip() for line in raw_story.split("\n")]))

     # for some unknown reason some lines miss a period, add it
     nonempty_lines = [_add_missing_period(line) for line in nonempty_lines]
@@ -135,13 +133,9 @@ def encode_for_summarization(story_lines, summary_lines, tokenizer):
         sentences.
     """
     story_lines_token_ids = [tokenizer.encode(line) for line in story_lines]
-    story_token_ids = [
-        token for sentence in story_lines_token_ids for token in sentence
-    ]
+    story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]
     summary_lines_token_ids = [tokenizer.encode(line) for line in summary_lines]
-    summary_token_ids = [
-        token for sentence in summary_lines_token_ids for token in sentence
-    ]
+    summary_token_ids = [token for sentence in summary_lines_token_ids for token in sentence]

     return story_token_ids, summary_token_ids
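The collapsed list comprehensions above flatten a list of per-sentence token-id lists into one flat sequence. A tiny self-contained illustration of that pattern follows; the token ids are made up for the example.

# Flattening pattern from encode_for_summarization above, on made-up ids.
story_lines_token_ids = [[101, 7592, 102], [101, 2088, 999, 102]]  # hypothetical per-sentence ids

# Read left to right: for each sentence, then for each token in that sentence.
story_token_ids = [token for sentence in story_lines_token_ids for token in sentence]

print(story_token_ids)  # [101, 7592, 102, 101, 2088, 999, 102]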