"web/vscode:/vscode.git/clone" did not exist on "d4273eadcd1bb7fa82540fc10d1f867c227924f7"
Commit 300ec300 authored by thomwolf's avatar thomwolf
Browse files

fixing run_generation example - using torch.no_grad

parent 1c377468
......@@ -87,11 +87,11 @@ def prepare_ctrl_input(args, _, tokenizer, prompt_text):
logger.info(
"WARNING! You are not starting your generation from a control code so you won't get good results"
)
return prompt_text, {}
return prompt_text
def prepare_xlm_input(args, model, tokenizer, prompt_text):
kwargs = {"language": None, "mask_token_id": None}
# kwargs = {"language": None, "mask_token_id": None}
# Set the language
use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
......@@ -107,14 +107,15 @@ def prepare_xlm_input(args, model, tokenizer, prompt_text):
+ str(list(available_languages))
+ " >>> "
)
kwargs["language"] = tokenizer.lang2id[language]
# kwargs["language"] = tokenizer.lang2id[language]
# TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
# XLM masked-language modeling (MLM) models need masked token
is_xlm_mlm = "mlm" in args.model_name_or_path
if is_xlm_mlm:
kwargs["mask_token_id"] = tokenizer.mask_token_id
# is_xlm_mlm = "mlm" in args.model_name_or_path
# if is_xlm_mlm:
# kwargs["mask_token_id"] = tokenizer.mask_token_id
return prompt_text, kwargs
return prompt_text
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
......@@ -179,8 +180,8 @@ def main():
try:
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
except KeyError as ke:
raise ke(
except KeyError:
raise KeyError(
"the model {} you specified is not supported. You are welcome to add it and open a PR :)"
)
......@@ -197,10 +198,9 @@ def main():
# Different models need different input formatting and/or extra arguments
requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
model_kwargs = {}
if requires_preprocessing:
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
prompt_text, model_kwargs = prepare_input(args, model, tokenizer, prompt_text)
prompt_text = prepare_input(args, model, tokenizer, prompt_text)
encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors='pt')
output_sequences = model.generate(
......@@ -210,14 +210,11 @@ def main():
top_k=args.k,
top_p=args.p,
repetition_penalty=args.repetition_penalty,
**model_kwargs,
)
generated_sequence = output_sequences.tolist()[
encoded_prompt.size(1) :
] # adapted to case where num_samples > 1
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
text = text[: text.find(args.stop_token) if args.stop_token else None]
generated_sequence = output_sequences.tolist()
text = [tokenizer.decode(seq, clean_up_tokenization_spaces=True) for seq in generated_sequence]
# text = text[: text.find(args.stop_token) if args.stop_token else None]
print(text)
......
......@@ -113,8 +113,8 @@ class XLMConfig(PretrainedConfig):
summary_first_dropout=0.1,
start_n_top=5,
end_n_top=5,
mask_token_id = 0,
lang_id = 0,
mask_token_id=0,
lang_id=0,
**kwargs):
"""Constructs XLMConfig.
"""
......
......@@ -489,7 +489,7 @@ class PreTrainedModel(nn.Module):
def generate(self, input_ids=None, max_length=None, do_sample=None, num_beams=None,
temperature=None, top_k=None, top_p=None, repetition_penalty=None,
bos_token_id=None, pad_token_id=None, eos_token_ids=None,
length_penalty=None, num_return_sequences=None, **model_kwargs):
length_penalty=None, num_return_sequences=None):
""" Sequence generator for models with a LM head.
The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
......@@ -519,7 +519,8 @@ class PreTrainedModel(nn.Module):
# We cannot generate if the model does not have a LM head
if self.get_output_embeddings() is None:
raise AttributeError("You tried do generated sequences with a model that does not have a LM Head.")
raise AttributeError("You tried to generate sequences with a model that does not have a LM Head."
"Please use another model class (e.g. `OpenAIGPTLMHeadModel`)")
max_length = max_length if max_length is not None else self.config.max_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
......@@ -544,7 +545,7 @@ class PreTrainedModel(nn.Module):
assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictely positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictely positive integer."
# assert temperature > 0, "`temperature` should be strictely positive."
# assert temperature >= 0, "`temperature` should be positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
......@@ -576,13 +577,11 @@ class PreTrainedModel(nn.Module):
output = self._generate_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, effective_batch_size,
length_penalty, num_beams, vocab_size,
**model_kwargs)
length_penalty, num_beams, vocab_size)
else:
output = self._generate_no_beam_search(input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, effective_batch_size,
**model_kwargs)
pad_token_id, eos_token_ids, effective_batch_size)
if num_return_sequences != 1:
output = output.view(batch_size, num_return_sequences, -1)
......@@ -590,19 +589,18 @@ class PreTrainedModel(nn.Module):
def _generate_no_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size,
**model_kwargs):
pad_token_id, eos_token_ids, batch_size):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequence are generated independantly.
"""
# current position / max lengths / length of generated sentences / unfinished sentences
unfinished_sents = input_ids.new(batch_size).fill_(1)
# cache compute states
# TODO: add cached compute states
pasts = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
......@@ -614,7 +612,7 @@ class PreTrainedModel(nn.Module):
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
if temperature > 0 and temperature != 1.0:
next_token_logits = next_token_logits / temperature
# Top-p/top-k filtering
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
......@@ -644,8 +642,7 @@ class PreTrainedModel(nn.Module):
def _generate_beam_search(self, input_ids, cur_len, max_length, do_sample,
temperature, top_k, top_p, repetition_penalty,
pad_token_id, eos_token_ids, batch_size,
length_penalty, num_beams, vocab_size,
**model_kwargs):
length_penalty, num_beams, vocab_size):
""" Generate sequences for each example with beam search.
"""
# Expand input to num beams
......@@ -667,7 +664,7 @@ class PreTrainedModel(nn.Module):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts, **model_kwargs)
model_inputs = self.prepare_inputs_for_generation(input_ids, pasts=pasts)
scores = self(**model_inputs)[0] # (batch_size * num_beams, cur_len, vocab_size)
scores = scores[:, -1, :] # (batch_size * num_beams, vocab_size)
......@@ -679,7 +676,7 @@ class PreTrainedModel(nn.Module):
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
if temperature != 1.0:
if temperature > 0 and temperature != 1.0:
scores = scores / temperature
# Top-p/top-k filtering
scores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2) # (batch_size * num_beams, vocab_size)
......
......@@ -639,9 +639,9 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
def get_output_embeddings(self):
return self.pred_layer.proj
def prepare_inputs_for_generation(self, input_ids, **model_kwargs):
mask_token_id = model_kwargs['mask_token_id'] if 'mask_token_id' in model_kwargs else self.config.mask_token_id
lang_id = model_kwargs['lang_id'] if 'lang_id' in model_kwargs else self.config.lang_id
def prepare_inputs_for_generation(self, input_ids, **kwargs):
mask_token_id = self.config.mask_token_id
lang_id = self.config.lang_id
mask_token = torch.full((1, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
input_ids = torch.cat([input_ids, mask_token], dim=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment