"src/lib/vscode:/vscode.git/clone" did not exist on "656f8dab05c842f964a02db15e99b580ae4d10f9"
Commit 07bc8efb authored by Rémi Louf's avatar Rémi Louf
Browse files

add greedy decoding and sampling

parent e57d00ee
......@@ -20,14 +20,10 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import argparse
import logging
from tqdm import trange
import torch
import torch.nn.functional as F
import numpy as np
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
......@@ -36,22 +32,22 @@ from transformers import CTRLLMHeadModel, CTRLTokenizer
from transformers import XLMWithLMHeadModel, XLMTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())
MODEL_CLASSES = {
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
'xlm': (XLMWithLMHeadModel, XLMTokenizer),
"gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
"ctrl": (CTRLLMHeadModel, CTRLTokenizer),
"openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
"xlnet": (XLNetLMHeadModel, XLNetTokenizer),
"transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
"xlm": (XLMWithLMHeadModel, XLMTokenizer),
}
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
......@@ -75,81 +71,78 @@ def set_seed(args):
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
#
# Functions to prepare models' input
#
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
if args.temperature > 0.7:
logger.info(
"CTRL typically works better with lower temperatures (and lower top_k)."
)
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
logger.info(
"WARNING! You are not starting your generation from a control code so you won't get good results"
)
return prompt_text, {}
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
""" Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size x vocabulary size)
top_k > 0: keep only top k tokens with highest probability (top-k filtering).
top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
"""
top_k = min(top_k, logits.size(-1)) # Safety check
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
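For reference, a minimal sketch of what the filter above does to a toy batch of logits (the tensor values are arbitrary and only serve as an illustration):

import torch
import torch.nn.functional as F

toy_logits = torch.tensor([[2.0, 1.0, 0.5, 0.1]])

# top_k=2: only the two largest logits survive, the rest are set to -inf
top_k_top_p_filtering(toy_logits.clone(), top_k=2)
# -> tensor([[2., 1., -inf, -inf]])

# top_p=0.6: tokens are kept until their cumulative probability exceeds 0.6
# (the token that crosses the threshold is kept as well, hence two survivors here)
filtered = top_k_top_p_filtering(toy_logits.clone(), top_p=0.6)
probabilities = F.softmax(filtered, dim=-1)  # renormalizes over the surviving tokens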
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
context = torch.tensor(context, dtype=torch.long, device=device)
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context
with torch.no_grad():
for _ in trange(length):
inputs = {'input_ids': generated}
if is_xlnet:
# XLNet is a direct (predict same token, not next token) and bi-directional model by default
# => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
target_mapping[0, 0, -1] = 1.0 # predict last token
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
if is_xlm_mlm and xlm_mask_token:
# XLM MLM models are direct models (predict same token, not next token)
# => need one additional dummy token in the input (will be masked and guessed)
input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
inputs = {'input_ids': input_ids}
if xlm_lang is not None:
inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
for i in range(num_samples):
for _ in set(generated[i].tolist()):
next_token_logits[i, _] /= repetition_penalty
filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
if temperature == 0: # greedy sampling:
next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
def prepare_xlm_input(args, model, tokenizer, prompt_text):
kwargs = {"language": None, "mask_token": None}
# Set the language
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
if hasattr(model.config, "lang2id") and use_lang_emb:
available_languages = model.config.lang2id.keys()
if args.xlm_language in available_languages:
language = args.xlm_language
else:
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
generated = torch.cat((generated, next_token), dim=1)
return generated
language = None
while language not in available_languages:
language = input(
"Using XLM. Select language in "
+ str(list(available_languages))
+ " >>> "
)
kwargs["language"] = tokenizer.lang2id[language]
# XLM masked-language modeling (MLM) models need a mask token
is_xlm_mlm = "mlm" in args.model_name_or_path
if is_xlm_mlm:
kwargs["mask_token"] = tokenizer.mask_token_id
return prompt_text, kwargs
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
return prompt_text, {}
def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
return prompt_text, {}
PREPROCESSING_FUNCTIONS = {
"ctrl": prepare_ctrl_input,
"xlm": prepare_xlm_input,
"xlnet": prepare_xlnet_input,
"transfo-xl": prepare_transfoxl_input,
}
def adjust_length_to_model(length, max_sequence_length):
if length < 0 and max_sequence_length > 0:
length = max_sequence_length
elif 0 < max_sequence_length < length:
length = max_sequence_length # No generation bigger than model size
elif length < 0:
length = MAX_LENGTH # avoid infinite loop
return length
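A few illustrative calls covering the branches above (the lengths are arbitrary):

adjust_length_to_model(-1, max_sequence_length=1024)    # -> 1024: negative length, bounded model
adjust_length_to_model(5000, max_sequence_length=1024)  # -> 1024: capped to the model size
adjust_length_to_model(20, max_sequence_length=1024)    # -> 20: already within bounds
adjust_length_to_model(-1, max_sequence_length=0)       # -> MAX_LENGTH: unbounded model, avoid an infinite loop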
def main():
......@@ -157,104 +150,81 @@ def main():
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--padding_text", type=str, default="")
parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--num_samples", type=int, default=1)
parser.add_argument("--temperature", type=float, default=1.0,
help="temperature of 0 implies greedy sampling")
parser.add_argument("--repetition_penalty", type=float, default=1.0,
help="primarily useful for CTRL model; in that case, use 1.2")
parser.add_argument("--top_k", type=int, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--stop_token', type=str, default=None,
help="Token at which text generation is stopped")
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 0 implies greedy sampling")
parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
parser.add_argument("--k", type=int, default=0)
parser.add_argument("--p", type=float, default=0.9)
parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
args = parser.parse_args()
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.device = torch.device(
"cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
)
args.n_gpu = torch.cuda.device_count()
set_seed(args)
# Initialize the model and tokenizer
try:
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
except KeyError:
raise KeyError(
"the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(args.model_type)
)
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.to(args.device)
model.eval()
if args.length < 0 and model.config.max_position_embeddings > 0:
args.length = model.config.max_position_embeddings
elif 0 < model.config.max_position_embeddings < args.length:
args.length = model.config.max_position_embeddings # No generation bigger than model size
elif args.length < 0:
args.length = MAX_LENGTH # avoid infinite loop
args.length = adjust_length_to_model(
args.length, max_sequence_length=model.config.max_position_embeddings
)
logger.info(args)
if args.model_type in ["ctrl"]:
if args.temperature > 0.7:
logger.info('CTRL typically works better with lower temperatures (and lower top_k).')
while True:
xlm_lang = None
# XLM language usage is detailed in issue #1414
if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
and model.config.use_lang_emb:
if args.xlm_lang:
language = args.xlm_lang
else:
language = None
while language not in tokenizer.lang2id.keys():
language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
xlm_lang = tokenizer.lang2id[language]
# XLM masked-language modeling (MLM) models need a mask token (see details in sample_sequence)
is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
if is_xlm_mlm:
xlm_mask_token = tokenizer.mask_token_id
else:
xlm_mask_token = None
raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
if args.model_type in ["transfo-xl", "xlnet"]:
# Models with memory like to have a long prompt for short inputs.
raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
context_tokens = tokenizer.encode(raw_text, add_special_tokens=False)
if args.model_type == "ctrl":
if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
out = sample_sequence(
model=model,
context=context_tokens,
num_samples=args.num_samples,
prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
# Different models need different input formatting and/or extra arguments
requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
model_kwargs = {}
if requires_preprocessing:
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
encoded_prompt = torch.tensor(tokenizer.encode(prompt_text, add_special_tokens=False)).unsqueeze(0)
output_sequences = model.decode(
prompt_ids=encoded_prompt,
length=args.length,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
k=args.k,
p=args.p,
repetition_penalty=args.repetition_penalty,
is_xlnet=bool(args.model_type == "xlnet"),
is_xlm_mlm=is_xlm_mlm,
xlm_mask_token=xlm_mask_token,
xlm_lang=xlm_lang,
device=args.device,
**model_kwargs,
)
out = out[:, len(context_tokens):].tolist()
for o in out:
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
generated_sequence = output_sequences.tolist()[encoded_prompt.size(1):]  # strip the prompt; decode currently returns a single sequence
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
text = text[: text.find(args.stop_token) if args.stop_token else None]
print(text)
if args.prompt:
break
return text
if __name__ == '__main__':
if __name__ == "__main__":
main()
......@@ -18,11 +18,14 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import logging
import os
import warnings
import torch
from torch import nn
from tqdm import trange
from .modeling_auto import AutoModel, AutoModelWithLMHead
from .modeling_utils import Sampler
logger = logging.getLogger(__name__)
......@@ -117,8 +120,7 @@ class PreTrainedEncoderDecoder(nn.Module):
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_")
and not argument.startswith("decoder_")
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
......@@ -186,51 +188,151 @@ class PreTrainedEncoderDecoder(nn.Module):
Indices of decoder input sequence tokens in the vocabulary.
kwargs: (`optional`) Remaining dictionary of keyword arguments.
"""
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
else:
encoder_outputs = ()
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
decoder_outputs = self.decoder(decoder_input_ids, encoder_hidden_states, **kwargs_decoder)
return decoder_outputs + encoder_outputs
def decode(
self,
encoder_input_ids,
decoder_prompt_ids=None,
device=torch.device("cpu"),
length=10,
do_sample=False,
temperature=1.0,
k=9,
p=0.,
repetition_penalty=1.,
**kwargs
):
""" Generic sequence generator for encoder-decoder models.
For encoder-decoders the generation consists in:
- Performing a forward pass through the encoder once;
- Passing the encoder's hidden states to a decoding mechanism that
repeatedly calls the decoder to generate sequences.
The method currently supports greedy decoding and sampling. See the
documentation of the `Sampler` class for more information about the
parameters related to sampling.
Params:
**encoder_input_ids**: `torch.LongTensor` of shape (1, sequence_length)
The sequence to encode.
**decoder_prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape (1, 0)
**device**: (`optional`) `torch.device`
The device on which the prompt_ids will be initialized if not provided.
**length**: (`optional`) int
The length of the sequence to be generated.
**do_sample**: (`optional`) bool
If set to `False` we use greedy decoding; otherwise sampling.
**temperature**: (`optional`) float
The value used to modulate the next token probabilities.
**k**: (`optional`) int
The parameter used for k-filtering.
**p**: (`optional`) float
The parameter for nucleus sampling. Must be between 0 and 1.
**repetition_penalty**: (`optional`) float
The parameter for repetition penalty.
"""
if decoder_prompt_ids is None:
decoder_prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)
# When the model does not have a LM head `get_output_embeddings`
# returns `None`. We use this mechanism to determine whether we
# should proceed with decoding or not.
if self.decoder.get_output_embeddings() is None:
raise AttributeError("You tried do generated sequences with a decoder that does not have a LM Head.")
# The followings checks that the decoder is on the same device as the one
# that is specified. It only works for models that fit on one GPU.
decoder_device = next(self.decoder.parameters()).device
if decoder_device != decoder_prompt_ids.device:
warnings.warn(
"The decoder is not on the same device as the prompt. Expected {}, got {}.".format(
decoder_prompt_ids.device, decoder_device
)
)
kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
with torch.no_grad():
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
sampler_config = {
"k": k,
"p": p,
"do_sample": do_sample,
"temperature": temperature,
"repetition_penalty": repetition_penalty,
}
return self._greedy_decode_or_sample(
decoder_prompt_ids, length, sampler_config, **kwargs_decoder
)
def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **kwargs_decoder):
sampler = Sampler(**sampler_config)
with torch.no_grad():
generated_sequence = prompt_ids
for _ in trange(length):
arguments = self.decoder._prepare_inputs_for_decoding(generated_sequence, **kwargs_decoder)
outputs = self.decoder(**arguments)
next_tokens_logits = outputs[0][:, -1, :]
next_tokens = sampler.get_one_token(next_tokens_logits, generated_sequence)
generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
return generated_sequence.squeeze(0)
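A hedged usage sketch of the decoding entry point above. It assumes a `Model2Model` checkpoint, as in the tests added further down; the token ids below are placeholders:

import torch
from transformers import Model2Model

model = Model2Model.from_pretrained("bert-base-uncased")
model.eval()

encoder_input_ids = torch.tensor([[101, 7592, 2088, 102]])  # placeholder ids for the source sequence
generated = model.decode(
    encoder_input_ids,
    decoder_prompt_ids=torch.tensor([[101]]),  # placeholder prompt for the decoder
    length=10,
    do_sample=True,
    k=5,
    p=0.9,
    repetition_penalty=1.2,
)
# `generated` is a 1-D LongTensor: the prompt followed by 10 generated token ids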
@staticmethod
def prepare_model_kwargs(**kwargs):
""" Prepare the encoder and decoder's keyword arguments.
Keyword arguments come in 3 flavors:
- encoder-specific (prefixed by `encoder_`)
- decoder-specific (prefixed by `decoder_`)
- those that apply to the model as whole.
We let the specific kwargs override the common ones in case of
conflict.
"""
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_")
and not argument.startswith("decoder_")
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
kwargs_encoder.update(
decoder_kwargs = kwargs_common.copy()
encoder_kwargs = kwargs_common.copy()
encoder_kwargs.update(
{
argument[len("encoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("encoder_")
}
)
kwargs_decoder.update(
decoder_kwargs.update(
{
argument[len("decoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
}
)
decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[
0
] # output the last layer hidden state
else:
encoder_outputs = ()
# Decode
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get(
"attention_mask", None
)
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
return decoder_outputs + encoder_outputs
return encoder_kwargs, decoder_kwargs
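For illustration, a minimal sketch of how the prefixed keyword arguments are dispatched (the argument names and tensors are made up; `Model2Model` inherits the static method):

import torch
from transformers import Model2Model

attention_mask = torch.ones((1, 4), dtype=torch.long)
token_type_ids = torch.zeros((1, 4), dtype=torch.long)
lm_labels = torch.tensor([[1, 2, 3, 4]])

encoder_kwargs, decoder_kwargs = Model2Model.prepare_model_kwargs(
    attention_mask=attention_mask,          # shared by both stacks
    encoder_token_type_ids=token_type_ids,  # encoder-only, prefix is stripped
    decoder_lm_labels=lm_labels,            # decoder-only, prefix is stripped
)
# encoder_kwargs -> {"attention_mask": ..., "token_type_ids": ...}
# decoder_kwargs -> {"attention_mask": ..., "lm_labels": ...,
#                    "encoder_attention_mask": ...}  # copied from the encoder's mask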
class Model2Model(PreTrainedEncoderDecoder):
......
......@@ -36,7 +36,7 @@ from torch.nn.parameter import Parameter
from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from .configuration_transfo_xl import TransfoXLConfig
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits, LogUniformSampler
from .file_utils import add_start_docstrings
logger = logging.getLogger(__name__)
......@@ -908,3 +908,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
outputs = [softmax_output, None] + outputs
return outputs # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
def get_output_embeddings(self):
""" Double-check if you are using adaptive softmax.
"""
if self.sample_softmax > 0:
return self.out_layer
else:
return self.crit.out_layers[-1]
......@@ -23,12 +23,14 @@ import json
import logging
import os
from io import open
import warnings
import six
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from tqdm import trange
from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME
......@@ -87,6 +89,93 @@ class PreTrainedModel(nn.Module):
def base_model(self):
return getattr(self, self.base_model_prefix, self)
def decode(self,
prompt_ids=None,
device=torch.device('cpu'),
length=10,
do_sample=False,
temperature=1.,
k=9,
p=0,
repetition_penalty=1,
**model_kwargs):
""" Generic sequence generator for single-stack models with a LM head.
The method currently supports greedy decoding and sampling. See the
documentation of the `Sampler` class for more information about the
parameters related to sampling.
Params:
**prompt_ids**: (`optional`) `torch.LongTensor` of shape (1, sequence_length)
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape (1, 0)
**device**: (`optional`) `torch.device`
The device on which the prompt_ids will be initialized if not provided.
**length**: (`optional`) int
The length of the sequence to be generated.
**do_sample**: (`optional`) bool
If set to `False` we use greedy decoding; otherwise sampling.
**temperature**: (`optional`) float
The value used to modulate the next token probabilities.
**k**: (`optional`) int
The parameter used for k-filtering.
**p**: (`optional`) float
The parameter for nucleus sampling. Must be between 0 and 1.
**repetition_penalty**: (`optional`) float
The parameter for repetition penalty.
"""
if prompt_ids is None:
prompt_ids = torch.tensor([[]], dtype=torch.long, device=device)
# When the model does not have a LM head `get_output_embeddings`
# returns `None`. We use this mechanism to determine whether we
# should proceed with decoding or not.
if self.get_output_embeddings() is None:
raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")
# The followings checks that the model is on the same device as the one
# that is specified. It only works for models that fit on one GPU.
model_device = next(self.parameters()).device
if model_device != prompt_ids.device:
warnings.warn(
"The model is not on the same device as the prompts. Expected {}, got {}.".format(
prompt_ids.device, model_device
)
)
sampler_config = {
"k": k,
"p": p,
"do_sample": do_sample,
"temperature": temperature,
"repetition_penalty": repetition_penalty,
}
return self._greedy_decode_or_sample(prompt_ids, length, sampler_config, **model_kwargs)
def _greedy_decode_or_sample(self, prompt_ids, length, sampler_config, **model_kwargs):
""" Generate text using greedy decoding or by sampling tokens."""
sampler = Sampler(**sampler_config)
generated_sequence = prompt_ids
with torch.no_grad():
for _ in trange(length):
arguments = self._prepare_inputs_for_decoding(generated_sequence, **model_kwargs)
outputs = self(**arguments)
next_tokens_logits = outputs[0][:, -1, :]
next_tokens = sampler.get_one_token(
next_tokens_logits, generated_sequence
)
generated_sequence = torch.cat((generated_sequence, next_tokens), dim=1)
return generated_sequence.squeeze(0)
def _prepare_inputs_for_decoding(self, input_ids, **kwargs):
arguments = {"input_ids": input_ids}
arguments.update(kwargs)
return arguments
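A hedged sketch of how the new `decode` entry point can be called on a single-stack LM; it mirrors what the updated `run_generation.py` above does (the prompt string is arbitrary, and downloading the `gpt2` weights is assumed):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

prompt_ids = torch.tensor(tokenizer.encode("The quick brown fox", add_special_tokens=False)).unsqueeze(0)
output_ids = model.decode(prompt_ids=prompt_ids, length=20, do_sample=True, k=10, p=0.9)
print(tokenizer.decode(output_ids.tolist(), clean_up_tokenization_spaces=True))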
def get_input_embeddings(self):
""" Get model's input embeddings
"""
......@@ -859,3 +948,143 @@ def prune_layer(layer, index, dim=None):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
class Sampler(object):
r""" Sampler is used to generate sequences of ids from logit inputs.
Greedy decoding, which consists in choosing the most probable token at each
step, is the default behaviour. Sampling with varying temperature, top_k
and nucleus filtering is also implemented.
Attributes:
**device**: ``torch.device``
Device on which the computations will be run.
**do_sample**: bool
Whether to sample or do greedy decoding.
**k**: int between 0 and vocab_size
Parameter for the top-k filtering
**p**: float between 0 and 1
Parameter for the nucleus filtering
**temperature**: strictly positive float
Parameter used to modulate the distribution over ids. Low temperatures
put more emphasis on highly probable tokens while high temperatures tend
to smooth the probability distribution.
**repetition_penalty**: strictly positive float
The penalty applied to repeating ids
"""
def __init__(
self, do_sample=False, k=9, p=0.0, temperature=1.0, repetition_penalty=1.0
):
self.k = k
self.p = p
self.do_sample = do_sample
self.temperature = temperature
self.repetition_penalty = repetition_penalty
self.do_apply_repetition_penalty = True if repetition_penalty > 1 else False
if self.p > 1:
warnings.warn(
"""You are trying to apply nucleus filtering with a value of p greater than 1 ({}).
However p is a probability and its value must lie between 0 and 1. In effect, no filtering
will be applied. If this is not the behavior you expect, change the value of p.""".format(
self.p
)
)
def get_one_token(self, next_token_logits, past_sequence):
logits = self.apply_repetition_penalty(next_token_logits, past_sequence)
if self.do_sample:
logits = self.apply_temperature(logits)
logits = self.apply_top_k_filter(logits)
logits = self.apply_nucleus_filter(logits)
return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
return torch.argmax(logits, dim=-1).unsqueeze(-1)
def apply_repetition_penalty(self, logits, past_sequence):
""" Apply a penalty to tokens that appear more than once in the
generated sequence.
.. Keskar, Nitish Shirish, et al. "Ctrl: A conditional transformer
language model for controllable generation." arXiv preprint
arXiv:1909.05858 (2019).
"""
if self.do_apply_repetition_penalty:
generated_token_idx = set(past_sequence[0].tolist())
for token_idx in generated_token_idx:
logits[0, token_idx] /= self.repetition_penalty
return logits
def apply_temperature(self, logits):
""" Shape the tokens' distribution through temperature. The higher the value
of the temperature, the more skewed towards high probability events the
distribution is.
.. Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning.
MIT press, 2016.
"""
# when dividing a float by 0, torch returns inf which in turn breaks the
# multinomial with an error message that is not very helpful. It is better
# for the user to break the execution and explain why.
if self.temperature == 0:
raise ZeroDivisionError(
"""You are trying to sample with a temperature equal to 0.
If you wanted to do greedy sampling, set instead `do_sample` to False.
Otherwise set the temperature to a value different from 0."""
)
return logits / self.temperature
def apply_top_k_filter(self, logits):
""" Use the probability distribution of the tokens to determine the set
to be sampled from. Specifically, we keep the k tokens with the highest
probability, i.e. the set of size k whose total probability is maximal.
.. Fan, Angela, Mike Lewis, and Yann Dauphin. "Hierarchical neural
story generation." arXiv preprint arXiv:1805.04833 (2018).
"""
if self.k > 0:
vocabulary_size = logits.size(-1)
if self.k > vocabulary_size:
warnings.warn(
"""You provided a value for k ({}) that is larger than the vocabulary size ({}).
We adjusted k's value to the vocabulary size; if that was what you intended to do
we recommend setting k to 0 instead. If this is not the behavior you expected,
choose a value of k that is smaller than the vocabulary size.""".format(
self.k, vocabulary_size
)
)
self.k = vocabulary_size
indices_to_remove = logits < torch.topk(logits, self.k)[0][..., -1, None]
logits[indices_to_remove] = -float("Inf")
return logits
def apply_nucleus_filter(self, logits):
""" Use the probability distribution of the tokens to determine the set
to be sampled from. Specifically, choose the smallest set such that the
sum of its items' probabilities is greater than a number p in [0,1].
.. Holtzman, Ari, et al. "The curious case of neural text
degeneration." arXiv preprint arXiv:1904.09751 (2019).
"""
if self.p > 0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
sorted_probabilities = F.softmax(sorted_logits, dim=-1)
cumulative_probabilities = torch.cumsum(sorted_probabilities, dim=-1)
# Remove tokens with cumulative probability above the threshold,
# but keep the first token above the threshold.
sorted_indices_to_remove = cumulative_probabilities > self.p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = -float("Inf")
return logits
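A minimal sketch of using the `Sampler` directly on toy logits; the values mirror the unit tests added further down:

import torch
from transformers.modeling_utils import Sampler

# Nucleus filtering with p=0.71: only the two most probable of the three
# tokens survive, so index 1 can never be drawn.
sampler = Sampler(do_sample=True, k=0, p=0.71)
logits = torch.tensor([[0.7, 0.1, 0.2]])
past = torch.tensor([[]], dtype=torch.long)
next_token = sampler.get_one_token(logits.clone(), past)  # shape (1, 1), value 0 or 2

# CTRL-style repetition penalty: the logit of an already generated token is
# divided by the penalty before the next token is chosen.
greedy = Sampler(do_sample=False, repetition_penalty=2.0)
penalized = greedy.apply_repetition_penalty(torch.tensor([[4.0, 2.0, 1.0]]), past_sequence=torch.tensor([[0]]))
# -> tensor([[2.0, 2.0, 1.0]])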
......@@ -657,6 +657,33 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
return outputs
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
mask_token = model_kwargs.pop("mask_token", None)
language = model_kwargs.pop("language", None)
input_ids = self._append_mask_token(input_ids, mask_token)
langs = self._create_language_embeddings(input_ids, language)
arguments = {"input_ids": input_ids, "langs": langs}
arguments.update(model_kwargs)
return arguments
@staticmethod
def _append_mask_token(sequence, mask_token_id):
""" Append a [MASK] token at the end of the sequence that the MLM model
is going to try to predict.
"""
if mask_token_id is not None:
tokens_to_append = torch.full((1, 1), mask_token_id, dtype=torch.long)
return torch.cat((sequence, tokens_to_append), dim=1)
return sequence
@staticmethod
def _create_language_embeddings(sequence, language):
if language is not None:
return torch.tensor([language] * sequence.shape[1]).view(1, -1)
return None
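For illustration, a hedged sketch of the inputs built above for an XLM MLM checkpoint; the token ids, mask id and language id below are placeholders:

import torch
from transformers import XLMWithLMHeadModel

sequence = torch.tensor([[5, 6, 7]], dtype=torch.long)  # placeholder token ids
input_ids = XLMWithLMHeadModel._append_mask_token(sequence, mask_token_id=0)  # placeholder mask id
langs = XLMWithLMHeadModel._create_language_embeddings(input_ids, language=4)  # placeholder language id
# input_ids -> tensor([[5, 6, 7, 0]]): the mask token is appended and will be predicted
# langs     -> tensor([[4, 4, 4, 4]]): one language id per position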
@add_start_docstrings("""XLM Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
......
......@@ -972,6 +972,40 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
def _prepare_inputs_for_decoding(self, input_ids, **model_kwargs):
input_ids = self._add_dummy_token(input_ids)
perm_mask = self._create_perm_mask(input_ids)
target_mapping = self._create_target_mapping(input_ids)
arguments = {
"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
}
return arguments
@staticmethod
def _add_dummy_token(sequence):
dummy = torch.zeros((sequence.size(0), 1), dtype=torch.long)
return torch.cat((sequence, dummy), dim=1)
@staticmethod
def _create_perm_mask(sequence):
mask = torch.zeros(
(sequence.shape[0], sequence.shape[1], sequence.shape[1]),
dtype=torch.float,
)
mask[:, :, -1] = 1.0 # Previous tokens don't see last token
return mask
@staticmethod
def _create_target_mapping(sequence):
target_mapping = torch.zeros(
(sequence.shape[0], 1, sequence.shape[1]),
dtype=torch.float,
)
target_mapping[0, 0, -1] = 1.0 # predict last token
return target_mapping
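Similarly, a hedged sketch of the tensors built above for XLNet decoding (token ids are placeholders; only the shapes matter here):

import torch
from transformers import XLNetLMHeadModel

sequence = torch.tensor([[5, 6, 7]], dtype=torch.long)  # placeholder token ids

input_ids = XLNetLMHeadModel._add_dummy_token(sequence)
# shape (1, 4): a dummy 0 is appended as the position to predict

perm_mask = XLNetLMHeadModel._create_perm_mask(input_ids)
# shape (1, 4, 4): the last column is 1.0, so no token attends to the dummy

target_mapping = XLNetLMHeadModel._create_target_mapping(input_ids)
# shape (1, 1, 4): selects the dummy position as the prediction target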
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks. """,
......
# coding=utf-8
import sys
import unittest
import numpy as np
import pytest
from transformers import is_torch_available
if is_torch_available():
import torch
from transformers import (
BertConfig,
BertModel,
GPT2Config,
GPT2LMHeadModel,
OpenAIGPTConfig,
OpenAIGPTLMHeadModel,
TransfoXLConfig,
TransfoXLLMHeadModel,
XLMConfig,
XLMWithLMHeadModel,
XLNetConfig,
XLNetLMHeadModel,
Model2Model,
)
from transformers.modeling_utils import Sampler
else:
pytestmark = pytest.mark.skip("Require Torch")
class SamplerTest(unittest.TestCase):
def test_nucleus_sampling(self):
inf = -float("Inf")
test_cases = (
{
"p": 0,
"logits": torch.tensor([0.3, 0.1, 0.2]),
"expected": torch.tensor([0.3, 0.1, 0.2]),
},
{
"p": 0.01,
"logits": torch.tensor([0.3, 0.1, 0.2]),
"expected": torch.tensor([0.3, inf, inf]),
},
{
"p": 1,
"logits": torch.tensor([0.3, 0.1, 0.2]),
"expected": torch.tensor([0.3, 0.1, 0.2]),
},
{
"p": 0.2,
"logits": torch.tensor([0.7, 0.1, 0.2]),
"expected": torch.tensor([0.7, inf, inf]),
},
{
"p": 0.71,
"logits": torch.tensor([0.7, 0.1, 0.2]),
"expected": torch.tensor([0.7, inf, 0.2]),
},
{
"p": 0.71,
"logits": torch.tensor([0.1, 0.7, 0.2]),
"expected": torch.tensor([inf, 0.7, 0.2]),
},
{
"p": 0.71,
"logits": torch.tensor([0.7, 0.2, 0.1]),
"expected": torch.tensor([0.7, 0.2, inf]),
},
{
"p": 0.91,
"logits": torch.tensor([0.7, 0.1, 0.2]),
"expected": torch.tensor([0.7, 0.1, 0.2]),
},
)
for case in test_cases:
config = {
"do_sample": True,
"temperature": 1.0,
"k": 0,
"p": case["p"],
"repetition_penalty": 1.0,
}
sampler = Sampler(**config)
filtered_logits = sampler.apply_nucleus_filter(case["logits"])
np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy())
def test_top_k_filter(self):
inf = -float("Inf")
test_cases = (
{
"k": 0,
"logits": torch.tensor([0.7, 0.1, 0.2]),
"expected": torch.tensor([0.7, 0.1, 0.2]),
},
{
"k": 1,
"logits": torch.tensor([0.7, 0.1, 0.2]),
"expected": torch.tensor([0.7, inf, inf]),
},
{
"k": 2,
"logits": torch.tensor([0.7, 0.1, 0.2]),
"expected": torch.tensor([0.7, inf, 0.2]),
},
{
"k": 3,
"logits": torch.tensor([0.7, 0.1, 0.2]),
"expected": torch.tensor([0.7, 0.1, 0.2]),
},
)
for case in test_cases:
config = {
"do_sample": True,
"temperature": 1.0,
"k": case["k"],
"p": 0,
"repetition_penalty": 1.0,
}
sampler = Sampler(**config)
filtered_logits = sampler.apply_top_k_filter(case["logits"])
np.testing.assert_array_equal(case["expected"].numpy(), filtered_logits.numpy())
@pytest.mark.skipif(sys.version_info < (3, 2), reason="assertWarns() requires Python >= 3.2")
def test_wrong_k_value(self):
case = {"k": 10, "vocab_size": 5}
config = {
"do_sample": True,
"temperature": 1.0,
"k": case["k"],
"p": 0,
"repetition_penalty": 1.0,
}
sampler = Sampler(**config)
next_token_logits = torch.rand(case["vocab_size"]).unsqueeze(0)
past_sequence = torch.tensor([])
with self.assertWarns(UserWarning):
_ = sampler.get_one_token(next_token_logits, past_sequence)
def test_zero_temperature(self):
temperature = 0
config = {
"do_sample": True,
"temperature": temperature,
"k": 0,
"p": 0,
"repetition_penalty": 1.0,
}
sampler = Sampler(**config)
next_token_logits = torch.rand(10).unsqueeze(0)
past_sequence = torch.tensor([])
with self.assertRaises(ZeroDivisionError):
_ = sampler.get_one_token(next_token_logits, past_sequence)
class SamplerSingleStackTest(unittest.TestCase):
def test_raises_exception_when_no_LM_head(self):
models = [BertModel(BertConfig())]
for model in models:
with self.assertRaises(AttributeError):
model.decode()
@pytest.mark.slow
def test_forward_pass_and_output_length(self):
models = {
"XLNet": XLNetLMHeadModel(XLNetConfig()),
"XLM": XLMWithLMHeadModel(XLMConfig()),
"TransfoXL": TransfoXLLMHeadModel(TransfoXLConfig()),
"GPT2": GPT2LMHeadModel(GPT2Config()),
"GPT": OpenAIGPTLMHeadModel(OpenAIGPTConfig()),
}
kwargs = {
"XLNet": {},
"XLM": {"mask_token": 0},
"TransfoXL": {},
"GPT2": {},
"GPT": {},
}
prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
generated_length = 5
expected_length = 8
for name, model in models.items():
kwargs_model = kwargs[name]
output = model.decode(prompt_ids=prompt, length=generated_length, **kwargs_model)
self.assertEqual(len(output), expected_length)
class SamplerEncoderDecoderTest(unittest.TestCase):
@pytest.mark.slow
def test_forward_pass_and_output_length(self):
model = Model2Model.from_pretrained("bert-base-uncased")
encoder_input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
prompt = torch.tensor([[1, 2, 3]], dtype=torch.long)
generated_length = 5
expected_length = 8
output = model.decode(
encoder_input_ids,
decoder_prompt_ids=prompt,
k=2,
p=0.5,
repetition_penalty=2,
length=generated_length,
)
self.assertEqual(len(output), expected_length)
if __name__ == "__main__":
unittest.main()