Commit c2c2ca0f authored by LysandreJik

Added XLM to run_generation, with prompt language selection.

parent 1569610f
@@ -26,12 +26,13 @@ import torch
 import torch.nn.functional as F
 import numpy as np
-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig
+from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
 from transformers import XLNetLMHeadModel, XLNetTokenizer
 from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
+from transformers import XLMWithLMHeadModel, XLMTokenizer
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -41,13 +42,14 @@ logger = logging.getLogger(__name__)
 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig)), ())
 MODEL_CLASSES = {
     'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
     'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
     'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
     'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
+    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
 }
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
@@ -103,7 +105,8 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
     return logits
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
+def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False,
+                    xlm_lang=None, device='cpu'):
     context = torch.tensor(context, dtype=torch.long, device=device)
     context = context.unsqueeze(0).repeat(num_samples, 1)
     generated = context
@@ -121,6 +124,9 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
                 target_mapping[0, 0, -1] = 1.0  # predict last token
                 inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
+            if xlm_lang is not None:
+                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1]).view(1, -1)
             outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
             next_token_logits = outputs[0][0, -1, :] / temperature
             filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
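
Note on the two added lines above: XLM's multilingual checkpoints accept an optional 'langs' input shaped like 'input_ids', carrying the language id at every position, which is what the hunk builds each step. A minimal sketch of that tensor construction, assuming a language id of 0 and a 5-token context (both values are illustrative):

    import torch

    xlm_lang = 0                                    # hypothetical id, e.g. tokenizer.lang2id['en']
    seq_len = 5                                     # length of the current context
    langs = torch.tensor([xlm_lang] * seq_len).view(1, -1)
    print(langs.shape)                              # torch.Size([1, 5]) -- one id per input position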
@@ -137,6 +143,7 @@ def main():
                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--prompt", type=str, default="")
     parser.add_argument("--padding_text", type=str, default="")
+    parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20)
     parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_k", type=int, default=0)
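
With the new flag, a hypothetical invocation of the script looks like the line below; the checkpoint name is illustrative (any XLM checkpoint that ships a lang2id table should behave the same), and omitting --xlm_lang triggers the interactive language prompt added in the next hunk:

    python run_generation.py --model_type xlm --model_name_or_path xlm-mlm-xnli15-1024 --xlm_lang en --prompt "Today the weather is"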
@@ -168,6 +175,17 @@ def main():
     print(args)
     while True:
+        xlm_lang = None
+        # XLM language usage is detailed in issue #1414
+        if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id'):
+            if args.xlm_lang:
+                language = args.xlm_lang
+            else:
+                language = None
+                while language not in tokenizer.lang2id.keys():
+                    language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
+            xlm_lang = tokenizer.lang2id[language]
         raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
         if args.model_type in ["transfo-xl", "xlnet"]:
             # Models with memory like to have a long prompt for short inputs.
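
The interactive fallback above relies on the tokenizer's lang2id mapping, whose exact contents depend on the checkpoint. A rough sketch of how that mapping yields the integer id consumed by sample_sequence, assuming the 15-language xlm-mlm-xnli15-1024 checkpoint:

    from transformers import XLMTokenizer

    tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-xnli15-1024')
    print(list(tokenizer.lang2id.keys()))   # e.g. ['ar', 'bg', 'de', ..., 'en', 'fr', ..., 'zh']
    xlm_lang = tokenizer.lang2id['en']      # integer id later broadcast into the 'langs' tensor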
@@ -180,11 +198,12 @@ def main():
             temperature=args.temperature,
             top_k=args.top_k,
             top_p=args.top_p,
-            device=args.device,
             is_xlnet=bool(args.model_type == "xlnet"),
+            xlm_lang=xlm_lang,
+            device=args.device,
         )
         out = out[0, len(context_tokens):].tolist()
-        text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
+        text = tokenizer.decode(out, clean_up_tokenization_spaces=True, skip_special_tokens=True)
         print(text)
         if args.prompt:
             break