@@ -556,46 +556,81 @@ class PreTrainedModel(nn.Module):
...
@@ -556,46 +556,81 @@ class PreTrainedModel(nn.Module):
length_penalty=None,
length_penalty=None,
num_return_sequences=None,
num_return_sequences=None,
):
):
""" Sequence generator for models with a LM head.
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
and beam-search.
Adapted in part from Facebook's XLM beam search code: https://github.com/facebookresearch/XLM
Adapted in part from `Facebook's XLM beam search code`_.
# We cannot generate if the model does not have a LM head
# We cannot generate if the model does not have a LM head
ifself.get_output_embeddings()isNone:
ifself.get_output_embeddings()isNone:
raiseAttributeError(
raiseAttributeError(
"You tried to generate sequences with a model that does not have a LM Head."
"You tried to generate sequences with a model that does not have a LM Head."
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`)"
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`, `XLNetLMHeadModel`, `GPT2LMHeadModel`, `CTRLLMHeadModel`, `T5WithLMHeadModel`, `TransfoXLLMHeadModel`)"