Unverified Commit ed9b8481 authored by Thomas Wolf, committed by GitHub

Merge pull request #1840 from huggingface/generation_sampler

[WIP] Sampling sequence generator for transformers
parents 7e17f09f f86ed231
@@ -20,14 +20,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import argparse
 import logging
-from tqdm import trange
 import torch
-import torch.nn.functional as F
 import numpy as np
-from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
 from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
 from transformers import XLNetLMHeadModel, XLNetTokenizer
@@ -36,22 +32,22 @@ from transformers import CTRLLMHeadModel, CTRLTokenizer
 from transformers import XLMWithLMHeadModel, XLMTokenizer

-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
 logger = logging.getLogger(__name__)

 MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop

-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig, CTRLConfig)), ())

 MODEL_CLASSES = {
-    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
-    'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
-    'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
-    'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
-    'xlm': (XLMWithLMHeadModel, XLMTokenizer),
+    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
+    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
+    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
+    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
+    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
+    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
 }

 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
@@ -75,81 +71,79 @@ def set_seed(args):
     if args.n_gpu > 0:
         torch.cuda.manual_seed_all(args.seed)

+#
+# Functions to prepare models' input
+#
+
+def prepare_ctrl_input(args, _, tokenizer, prompt_text):
+    if args.temperature > 0.7:
+        logger.info(
+            "CTRL typically works better with lower temperatures (and lower top_k)."
+        )
+
+    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
+    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
+        logger.info(
+            "WARNING! You are not starting your generation from a control code so you won't get good results"
+        )
+    return prompt_text
+
+
+def prepare_xlm_input(args, model, tokenizer, prompt_text):
+    # kwargs = {"language": None, "mask_token_id": None}
+
+    # Set the language
+    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
+    if hasattr(model.config, "lang2id") and use_lang_emb:
+        available_languages = model.config.lang2id.keys()
+        if args.xlm_language in available_languages:
+            language = args.xlm_language
+        else:
+            language = None
+            while language not in available_languages:
+                language = input(
+                    "Using XLM. Select language in "
+                    + str(list(available_languages))
+                    + " >>> "
+                )
+        # kwargs["language"] = tokenizer.lang2id[language]
+
+    # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
+    # XLM masked-language modeling (MLM) models need masked token
+    # is_xlm_mlm = "mlm" in args.model_name_or_path
+    # if is_xlm_mlm:
+    #     kwargs["mask_token_id"] = tokenizer.mask_token_id
+
+    return prompt_text
+
+
+def prepare_xlnet_input(args, _, tokenizer, prompt_text):
+    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
+    return prompt_text, {}
+
+
+def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
+    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
+    return prompt_text, {}
+
+
+PREPROCESSING_FUNCTIONS = {
+    "ctrl": prepare_ctrl_input,
+    "xlm": prepare_xlm_input,
+    "xlnet": prepare_xlnet_input,
+    "transfo-xl": prepare_transfoxl_input,
+}
+
+
+def adjust_length_to_model(length, max_sequence_length):
+    if length < 0 and max_sequence_length > 0:
+        length = max_sequence_length
+    elif 0 < max_sequence_length < length:
+        length = max_sequence_length  # No generation bigger than model size
+    elif length < 0:
+        length = MAX_LENGTH  # avoid infinite loop
+    return length
+
-def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
-    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-        Args:
-            logits: logits distribution shape (batch size x vocabulary size)
-            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
-            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
-                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-    top_k = min(top_k, logits.size(-1))  # Safety check
-    if top_k > 0:
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p > 0.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold
-        sorted_indices_to_remove = cumulative_probs > top_p
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-
-        # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
-        logits[indices_to_remove] = filter_value
-    return logits
-
-
-def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
-                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
-    context = torch.tensor(context, dtype=torch.long, device=device)
-    context = context.unsqueeze(0).repeat(num_samples, 1)
-    generated = context
-    with torch.no_grad():
-        for _ in trange(length):
-            inputs = {'input_ids': generated}
-            if is_xlnet:
-                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
-                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
-                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
-                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
-                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
-                target_mapping[0, 0, -1] = 1.0  # predict last token
-                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}
-
-            if is_xlm_mlm and xlm_mask_token:
-                # XLM MLM models are direct models (predict same token, not next token)
-                # => need one additional dummy token in the input (will be masked and guessed)
-                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
-                inputs = {'input_ids': input_ids}
-
-            if xlm_lang is not None:
-                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)
-
-            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
-            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
-
-            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
-            for i in range(num_samples):
-                for _ in set(generated[i].tolist()):
-                    next_token_logits[i, _] /= repetition_penalty
-
-            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-            if temperature == 0:  # greedy sampling:
-                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
-            else:
-                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-            generated = torch.cat((generated, next_token), dim=1)
-    return generated
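For reference, the removed top_k_top_p_filtering helper implemented the top-k / nucleus (top-p) truncation described in its docstring; with this PR the same truncation is applied inside the library's generate() path instead of in the example script. A minimal sketch of the top-k step on a toy batch of logits (values are made up for illustration):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])  # batch of 1, vocabulary of 5 tokens

# top_k=2: mask every logit smaller than the 2nd-largest one
top_k = 2
kth_best = torch.topk(logits, top_k)[0][..., -1, None]
filtered = logits.masked_fill(logits < kth_best, float("-inf"))

probs = F.softmax(filtered, dim=-1)                   # only tokens 0 and 1 keep non-zero probability
next_token = torch.multinomial(probs, num_samples=1)  # sample from the truncated distribution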
 def main():
@@ -157,108 +151,76 @@ def main():
     parser.add_argument("--model_type", default=None, type=str, required=True,
                         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
     parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
-                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
     parser.add_argument("--prompt", type=str, default="")
-    parser.add_argument("--padding_text", type=str, default="")
-    parser.add_argument("--xlm_lang", type=str, default="", help="Optional language when used with the XLM model.")
     parser.add_argument("--length", type=int, default=20)
-    parser.add_argument("--num_samples", type=int, default=1)
-    parser.add_argument("--temperature", type=float, default=1.0,
-                        help="temperature of 0 implies greedy sampling")
-    parser.add_argument("--repetition_penalty", type=float, default=1.0,
-                        help="primarily useful for CTRL model; in that case, use 1.2")
-    parser.add_argument("--top_k", type=int, default=0)
-    parser.add_argument("--top_p", type=float, default=0.9)
-    parser.add_argument("--no_cuda", action='store_true',
-                        help="Avoid using CUDA when available")
-    parser.add_argument('--seed', type=int, default=42,
-                        help="random seed for initialization")
-    parser.add_argument('--stop_token', type=str, default=None,
-                        help="Token at which text generation is stopped")
+    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
+
+    parser.add_argument("--temperature", type=float, default=1.0, help="temperature of 1.0 has no effect, lower tend toward greedy sampling")
+    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2")
+    parser.add_argument("--k", type=int, default=0)
+    parser.add_argument("--p", type=float, default=0.9)
+
+    parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
+    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
+
+    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
+    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
     args = parser.parse_args()

-    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+    args.device = torch.device(
+        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
+    )
     args.n_gpu = torch.cuda.device_count()

     set_seed(args)

-    args.model_type = args.model_type.lower()
-    model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    # Initialize the model and tokenizer
+    try:
+        args.model_type = args.model_type.lower()
+        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    except KeyError:
+        raise KeyError(
+            "the model {} you specified is not supported. You are welcome to add it and open a PR :)"
+        )
+
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
     model = model_class.from_pretrained(args.model_name_or_path)
     model.to(args.device)
-    model.eval()

-    if args.length < 0 and model.config.max_position_embeddings > 0:
-        args.length = model.config.max_position_embeddings
-    elif 0 < model.config.max_position_embeddings < args.length:
-        args.length = model.config.max_position_embeddings  # No generation bigger than model size
-    elif args.length < 0:
-        args.length = MAX_LENGTH  # avoid infinite loop
+    args.length = adjust_length_to_model(
+        args.length, max_sequence_length=model.config.max_position_embeddings
+    )
     logger.info(args)
-    if args.model_type in ["ctrl"]:
-        if args.temperature > 0.7:
-            logger.info('CTRL typically works better with lower temperatures (and lower top_k).')

-    while True:
-        xlm_lang = None
-        # XLM Language usage detailed in the issues #1414
-        if args.model_type in ["xlm"] and hasattr(tokenizer, 'lang2id') and hasattr(model.config, 'use_lang_emb') \
-                and model.config.use_lang_emb:
-            if args.xlm_lang:
-                language = args.xlm_lang
-            else:
-                language = None
-                while language not in tokenizer.lang2id.keys():
-                    language = input("Using XLM. Select language in " + str(list(tokenizer.lang2id.keys())) + " >>> ")
-            xlm_lang = tokenizer.lang2id[language]
-
-        # XLM masked-language modeling (MLM) models need masked token (see details in sample_sequence)
-        is_xlm_mlm = args.model_type in ["xlm"] and 'mlm' in args.model_name_or_path
-        if is_xlm_mlm:
-            xlm_mask_token = tokenizer.mask_token_id
-        else:
-            xlm_mask_token = None
-
-        raw_text = args.prompt if args.prompt else input("Model prompt >>> ")
-        if args.model_type in ["transfo-xl", "xlnet"]:
-            # Models with memory likes to have a long prompt for short inputs.
-            raw_text = (args.padding_text if args.padding_text else PADDING_TEXT) + raw_text
-        context_tokens = tokenizer.encode(raw_text, add_special_tokens=False)
-        if args.model_type == "ctrl":
-            if not any(context_tokens[0] == x for x in tokenizer.control_codes.values()):
-                logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
-        out = sample_sequence(
-            model=model,
-            context=context_tokens,
-            num_samples=args.num_samples,
-            length=args.length,
-            temperature=args.temperature,
-            top_k=args.top_k,
-            top_p=args.top_p,
-            repetition_penalty=args.repetition_penalty,
-            is_xlnet=bool(args.model_type == "xlnet"),
-            is_xlm_mlm=is_xlm_mlm,
-            xlm_mask_token=xlm_mask_token,
-            xlm_lang=xlm_lang,
-            device=args.device,
-        )
-        out = out[:, len(context_tokens):].tolist()
-        for o in out:
-            text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
-            if args.stop_token:
-                index = text.find(args.stop_token)
-                if index == -1:
-                    index = None
-                text = text[:index]
-
-            print(text)
-
-        if args.prompt:
-            break
+    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
+
+    # Different models need different input formatting and/or extra arguments
+    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
+    if requires_preprocessing:
+        prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
+        prompt_text = prepare_input(args, model, tokenizer, prompt_text)
+    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
+
+    output_sequences = model.generate(
+        input_ids=encoded_prompt,
+        max_length=args.length,
+        temperature=args.temperature,
+        top_k=args.k,
+        top_p=args.p,
+        repetition_penalty=args.repetition_penalty,
+    )
+
+    # Batch size == 1. to add more examples please use num_return_sequences > 1
+    generated_sequence = output_sequences[0].tolist()
+    text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
+    text = text[: text.find(args.stop_token) if args.stop_token else None]
+
+    print(text)

     return text


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
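After this change the example script is entirely driven by model.generate(). A minimal stand-alone sketch of the same encode, generate, decode path, using the public gpt2 checkpoint purely as an example (weights are downloaded on first use):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Encode the prompt as a (1, seq_len) tensor, exactly as run_generation.py now does.
encoded_prompt = tokenizer.encode("Once upon a time", add_special_tokens=False, return_tensors="pt")

output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=40,
    temperature=1.0,
    top_k=0,
    top_p=0.9,
    repetition_penalty=1.0,
)

print(tokenizer.decode(output_sequences[0].tolist(), clean_up_tokenization_spaces=True))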
@@ -56,8 +56,24 @@ class PretrainedConfig(object):
         self.torchscript = kwargs.pop('torchscript', False)  # Only used by PyTorch models
         self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
         self.pruned_heads = kwargs.pop('pruned_heads', {})
+
+        # is_decoder is used in encoder-decoder models to differentiate encoder from decoder
         self.is_decoder = kwargs.pop('is_decoder', False)

+        # Parameters for sequence generation
+        self.max_length = kwargs.pop('max_length', 20)
+        self.do_sample = kwargs.pop('do_sample', False)
+        self.num_beams = kwargs.pop('num_beams', 1)
+        self.temperature = kwargs.pop('temperature', 1.0)
+        self.top_k = kwargs.pop('top_k', 50)
+        self.top_p = kwargs.pop('top_p', 1.0)
+        self.repetition_penalty = kwargs.pop('repetition_penalty', 1.0)
+        self.bos_token_id = kwargs.pop('bos_token_id', 0)
+        self.pad_token_id = kwargs.pop('pad_token_id', 0)
+        self.eos_token_ids = kwargs.pop('eos_token_ids', 0)
+        self.length_penalty = kwargs.pop('length_penalty', 1.)
+        self.num_return_sequences = kwargs.pop('num_return_sequences', 1)
+
         # Fine-tuning task arguments
         self.finetuning_task = kwargs.pop('finetuning_task', None)
         self.num_labels = kwargs.pop('num_labels', 2)
...
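These new configuration attributes act as the defaults for the sequence generator: when generate() is called without an explicit value for one of them, it is expected to fall back to the corresponding config field. A small sketch of that assumption (gpt2 is only an example checkpoint):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Generation defaults now live on the configuration object.
model.config.max_length = 30
model.config.do_sample = True
model.config.top_k = 50
model.config.temperature = 0.8

input_ids = tokenizer.encode("The sampling generator", return_tensors="pt")
output = model.generate(input_ids)  # no explicit arguments: the config values above are used
print(tokenizer.decode(output[0].tolist()))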
@@ -110,6 +110,8 @@ class XLMConfig(PretrainedConfig):
                  summary_first_dropout=0.1,
                  start_n_top=5,
                  end_n_top=5,
+                 mask_token_id=0,
+                 lang_id=0,
                  **kwargs):
         """Constructs XLMConfig.
         """
@@ -143,6 +145,8 @@ class XLMConfig(PretrainedConfig):
         self.summary_first_dropout = summary_first_dropout
         self.start_n_top = start_n_top
         self.end_n_top = end_n_top
+        self.mask_token_id = mask_token_id
+        self.lang_id = lang_id
         if "n_words" in kwargs:
             self.n_words = kwargs["n_words"]
...
@@ -18,9 +18,11 @@ from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
+import warnings

 import torch
 from torch import nn
+from tqdm import trange

 from .modeling_auto import AutoModel, AutoModelWithLMHead
@@ -119,8 +121,7 @@ class PreTrainedEncoderDecoder(nn.Module):
         kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not argument.startswith("encoder_")
-            and not argument.startswith("decoder_")
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
         }
         kwargs_decoder = kwargs_common.copy()
         kwargs_encoder = kwargs_common.copy()
@@ -220,49 +221,56 @@ class PreTrainedEncoderDecoder(nn.Module):
                 Indices of decoder input sequence tokens in the vocabulary.
             kwargs: (`optional`) Remaining dictionary of keyword arguments.
         """
-        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
-        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
-        # that apply to the model as whole.
-        # We let the specific kwargs override the common ones in case of conflict.
+        kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
+
+        # Encode if needed (training, first prediction pass)
+        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
+        if encoder_hidden_states is None:
+            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
+            encoder_hidden_states = encoder_outputs[0]
+        else:
+            encoder_outputs = ()
+
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
+        decoder_outputs = self.decoder(decoder_input_ids, encoder_hidden_states, **kwargs_decoder)
+
+        return decoder_outputs + encoder_outputs
+
+    @staticmethod
+    def prepare_model_kwargs(**kwargs):
+        """ Prepare the encoder and decoder's keyword arguments.
+
+        Keyword arguments come in 3 flavors:
+        - encoder-specific (prefixed by `encoder_`)
+        - decoder-specific (prefixed by `decoder_`)
+        - those that apply to the model as whole.
+
+        We let the specific kwargs override the common ones in case of
+        conflict.
+        """
         kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not argument.startswith("encoder_")
-            and not argument.startswith("decoder_")
+            if not argument.startswith("encoder_") and not argument.startswith("decoder_")
         }
-        kwargs_decoder = kwargs_common.copy()
-        kwargs_encoder = kwargs_common.copy()
-        kwargs_encoder.update(
+        decoder_kwargs = kwargs_common.copy()
+        encoder_kwargs = kwargs_common.copy()
+        encoder_kwargs.update(
             {
                 argument[len("encoder_") :]: value
                 for argument, value in kwargs.items()
                 if argument.startswith("encoder_")
             }
         )
-        kwargs_decoder.update(
+        decoder_kwargs.update(
             {
                 argument[len("decoder_") :]: value
                 for argument, value in kwargs.items()
                 if argument.startswith("decoder_")
             }
         )
+        decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
+        return encoder_kwargs, decoder_kwargs
-
-        # Encode if needed (training, first prediction pass)
-        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
-        if encoder_hidden_states is None:
-            encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[0]
-        else:
-            encoder_outputs = ()
-
-        # Decode
-        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
-        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get(
-            "attention_mask", None
-        )
-        decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
-
-        return decoder_outputs + encoder_outputs


 class Model2Model(PreTrainedEncoderDecoder):
...
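A small sketch of what the new prepare_model_kwargs static method is expected to return for a mixed set of keyword arguments (assuming PreTrainedEncoderDecoder is importable from the package root, as in this version of the library):

import torch
from transformers import PreTrainedEncoderDecoder

attention_mask = torch.ones(1, 4, dtype=torch.long)
lm_labels = torch.tensor([[5, 6, 7, 8]])

encoder_kwargs, decoder_kwargs = PreTrainedEncoderDecoder.prepare_model_kwargs(
    attention_mask=attention_mask,   # common argument: copied to both sides
    decoder_lm_labels=lm_labels,     # decoder-specific: the "decoder_" prefix is stripped
)
# encoder_kwargs -> {"attention_mask": attention_mask}
# decoder_kwargs -> {"attention_mask": attention_mask,
#                    "lm_labels": lm_labels,
#                    "encoder_attention_mask": attention_mask}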
@@ -36,7 +36,7 @@ from torch.nn.parameter import Parameter
 from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
 from .configuration_transfo_xl import TransfoXLConfig
-from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
+from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits, LogUniformSampler
 from .file_utils import add_start_docstrings

 logger = logging.getLogger(__name__)
@@ -908,3 +908,11 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         outputs = [softmax_output, None] + outputs

         return outputs  # (loss), logits or None if labels is not None (speed up adaptive softmax), new_mems, (all hidden states), (all attentions)
+
+    def get_output_embeddings(self):
+        """ Double-check if you are using adaptive softmax.
+        """
+        if self.sample_softmax > 0:
+            return self.out_layer
+        else:
+            return self.crit.out_layers[-1]
This diff is collapsed.
@@ -649,6 +649,18 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
     def get_output_embeddings(self):
         return self.pred_layer.proj

+    def prepare_inputs_for_generation(self, input_ids, **kwargs):
+        mask_token_id = self.config.mask_token_id
+        lang_id = self.config.lang_id
+
+        mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
+        input_ids = torch.cat([input_ids, mask_token], dim=1)
+        if lang_id is not None:
+            langs = torch.full_like(input_ids, lang_id)
+        else:
+            langs = None
+        return {"input_ids": input_ids, "langs": langs}
+
     def forward(self, input_ids=None, attention_mask=None, langs=None, token_type_ids=None, position_ids=None,
                 lengths=None, cache=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
...
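Because XLM's masked-language-modeling head predicts the token at the masked position rather than the next token, generation appends one mask token per step. A quick sketch of what this hook returns (the checkpoint name is only an example; any XLM MLM model would do):

import torch
from transformers import XLMWithLMHeadModel, XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")

input_ids = tokenizer.encode("The sampling generator", return_tensors="pt")
inputs = model.prepare_inputs_for_generation(input_ids)

# One mask token has been appended, so the sequence is one token longer, and
# "langs" (built from config.lang_id when it is set) has the same shape as input_ids.
assert inputs["input_ids"].shape[1] == input_ids.shape[1] + 1
assert inputs["input_ids"][0, -1].item() == model.config.mask_token_id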
@@ -947,6 +947,30 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_loss

+    def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
+        # Add dummy token at the end (no attention on this one)
+        dummy_token = torch.zeros((1, 1), dtype=torch.long, device=input_ids.device)
+        input_ids = torch.cat([input_ids, dummy_token], dim=1)
+
+        # Build permutation mask so that previous tokens don't see last token
+        perm_mask = torch.zeros(
+            (input_ids.shape[0], input_ids.shape[1], input_ids.shape[1]),
+            dtype=torch.float, device=input_ids.device
+        )
+        perm_mask[:, :, -1] = 1.0
+
+        # We'll only predict the last token
+        target_mapping = torch.zeros(
+            (input_ids.shape[0], 1, input_ids.shape[1]),
+            dtype=torch.float, device=input_ids.device
+        )
+        target_mapping[0, 0, -1] = 1.0
+
+        return {"input_ids": input_ids,
+                "perm_mask": perm_mask,
+                "target_mapping": target_mapping
+                }
+
     def forward(self, input_ids=None, attention_mask=None, mems=None, perm_mask=None, target_mapping=None,
                 token_type_ids=None, input_mask=None, head_mask=None, inputs_embeds=None, labels=None):
         transformer_outputs = self.transformer(input_ids,
...
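XLNet likewise predicts the token at a target position rather than the next token, so this hook appends a dummy token and builds a permutation mask and target mapping over it. A sketch of the resulting shapes (xlnet-base-cased is just an example checkpoint):

import torch
from transformers import XLNetLMHeadModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")

input_ids = tokenizer.encode("The sampling generator", return_tensors="pt")
inputs = model.prepare_inputs_for_generation(input_ids)

seq_len = input_ids.shape[1] + 1  # one dummy token was appended
assert inputs["input_ids"].shape == (1, seq_len)
assert inputs["perm_mask"].shape == (1, seq_len, seq_len)   # last column is masked for every position
assert inputs["target_mapping"].shape == (1, 1, seq_len)    # only the last position is predicted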
@@ -773,7 +773,7 @@ class PreTrainedTokenizer(object):
                 padding index, up to their max length. If no max length is specified, the padding is done up to the model's max length.
                 The tokenizer padding sides are handled by the following strings:
                 - 'left': pads on the left of the sequences
                 - 'right': pads on the right of the sequences
                 Defaults to False: no padding.
             return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
                 or PyTorch torch.Tensor instead of a list of python integers.
...