Commit ca2047bc authored by Patrick von Platen

refactor variable naming and improve tf generate in line with torch generate

parent 41b437ea
@@ -722,6 +722,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         max_length = max_length if max_length is not None else self.config.max_length
         min_length = min_length if min_length is not None else self.config.min_length
+        do_sample = do_sample if do_sample is not None else self.config.do_sample
         early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
         num_beams = num_beams if num_beams is not None else self.config.num_beams
         temperature = temperature if temperature is not None else self.config.temperature
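Note: the added line folds do_sample into the same fallback pattern as its neighbours: an explicit argument to generate() wins, otherwise the model config supplies the default. A minimal standalone sketch of that pattern (the Config class and its values are made up for illustration, not the transformers API):

    # Sketch: "explicit argument, else config default" (Config is hypothetical).
    class Config:
        max_length = 20
        do_sample = False

    def generate(config, max_length=None, do_sample=None):
        max_length = max_length if max_length is not None else config.max_length
        do_sample = do_sample if do_sample is not None else config.do_sample
        return max_length, do_sample

    config = Config()
    assert generate(config) == (20, False)                 # config defaults apply
    assert generate(config, do_sample=True) == (20, True)  # explicit value wins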
@@ -852,7 +853,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 device=next(self.parameters()).device,
             )
             cur_len = 1
-            self.model.decoder.generation_mode = True
+
+            # put model in generation mode if it has one
+            if hasattr(self.model, "generation_mode"):
+                self.model.decoder.generation_mode = True
         else:
             encoder_inputs = None
             cur_len = input_ids.shape[-1]
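Note: previously the decoder's generation_mode flag was set unconditionally, which assumed every model has one; the hasattr guard keeps generate() working for models without that flag (the check is made on self.model, the assignment on self.model.decoder, exactly as the diff shows). A rough sketch of the guard pattern with hypothetical stand-in classes:

    # Sketch: toggle an optional flag only when the model advertises it.
    # Decoder and Seq2SeqModel are hypothetical stand-ins, not transformers classes.
    class Decoder:
        def __init__(self):
            self.generation_mode = False

    class Seq2SeqModel:
        generation_mode = False  # flag advertised on the wrapper

        def __init__(self):
            self.decoder = Decoder()

    class PlainModel:
        pass  # no generation_mode anywhere

    for model in (Seq2SeqModel(), PlainModel()):
        if hasattr(model, "generation_mode"):
            model.decoder.generation_mode = True  # only reached for Seq2SeqModel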
@@ -860,44 +864,44 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         if num_beams > 1:
             output = self._generate_beam_search(
                 input_ids,
-                cur_len,
-                max_length,
-                min_length,
-                do_sample,
-                early_stopping,
-                temperature,
-                top_k,
-                top_p,
-                repetition_penalty,
-                no_repeat_ngram_size,
-                bos_token_id,
-                pad_token_id,
-                eos_token_ids,
-                effective_batch_size,
-                num_return_sequences,
-                length_penalty,
-                num_beams,
-                vocab_size,
-                encoder_inputs,
-                attention_mask,
+                cur_len=cur_len,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=do_sample,
+                early_stopping=early_stopping,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                bos_token_id=bos_token_id,
+                pad_token_id=pad_token_id,
+                eos_token_ids=eos_token_ids,
+                batch_size=effective_batch_size,
+                num_return_sequences=num_return_sequences,
+                length_penalty=length_penalty,
+                num_beams=num_beams,
+                vocab_size=vocab_size,
+                encoder_inputs=encoder_inputs,
+                attention_mask=attention_mask,
             )
         else:
             output = self._generate_no_beam_search(
                 input_ids,
-                cur_len,
-                max_length,
-                min_length,
-                do_sample,
-                temperature,
-                top_k,
-                top_p,
-                repetition_penalty,
-                no_repeat_ngram_size,
-                pad_token_id,
-                eos_token_ids,
-                effective_batch_size,
-                encoder_inputs,
-                attention_mask,
+                cur_len=cur_len,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                no_repeat_ngram_size=no_repeat_ngram_size,
+                pad_token_id=pad_token_id,
+                eos_token_ids=eos_token_ids,
+                batch_size=effective_batch_size,
+                encoder_inputs=encoder_inputs,
+                attention_mask=attention_mask,
             )
         return output
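Note: the only change in this hunk is switching both internal calls from positional to keyword arguments. With signatures this long, one misplaced positional argument silently shifts every argument after it; binding by name makes the calls order-independent and self-documenting. A toy illustration (the search function below is a made-up stand-in):

    # Sketch: positional calls bind by position, so a swap goes unnoticed;
    # keyword calls bind by name.
    def search(max_length, min_length, top_k, top_p):
        return {"max_length": max_length, "min_length": min_length,
                "top_k": top_k, "top_p": top_p}

    bad = search(50, 5, 0.9, 40)  # top_k/top_p swapped; no error is raised
    assert bad["top_k"] == 0.9    # silently wrong

    good = search(min_length=5, max_length=50, top_p=0.9, top_k=40)
    assert good["top_k"] == 40    # order no longer matters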
@@ -1157,24 +1161,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 next_sent_beam = []

                 # next tokens for this sentence
-                for i, (idx, score) in enumerate(zip(next_tokens[batch_idx], next_scores[batch_idx])):
+                for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
+                    zip(next_tokens[batch_idx], next_scores[batch_idx])
+                ):
                     # get beam and word IDs
-                    beam_id = idx // vocab_size
-                    token_id = idx % vocab_size
+                    beam_id = beam_token_id // vocab_size
+                    token_id = beam_token_id % vocab_size
                     effective_beam_id = batch_idx * num_beams + beam_id
                     # add to generated hypotheses if end of sentence
                     if (eos_token_ids is not None) and (token_id.item() in eos_token_ids):
-                        # when passed to num_beams hypotheses, continue
-                        if i >= num_beams:
+                        # if beam_token does not belong to top num_beams tokens, it should not be added
+                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
+                        if is_beam_token_worse_than_top_num_beams:
                             continue
                         generated_hyps[batch_idx].add(
-                            input_ids[effective_beam_id].clone(), score.item(),
+                            input_ids[effective_beam_id].clone(), beam_token_score.item(),
                         )
                     else:
                         # add next predicted word if it is not eos_token
-                        next_sent_beam.append((score, token_id, effective_beam_id))
+                        next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

                     # the beam for next step is full
                     if len(next_sent_beam) == num_beams:
...
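Note on the renamed indices above: per batch entry, beam-search scores are flattened over num_beams * vocab_size candidates before the top-k selection, so each flat index decodes into the beam it came from (// vocab_size) and the vocabulary token it adds (% vocab_size). A toy sketch of that decomposition (shapes and sizes are illustrative only, not the transformers implementation):

    # Sketch: decode flat top-k indices back into (beam_id, token_id) pairs.
    import torch

    num_beams, vocab_size = 2, 5
    scores = torch.randn(num_beams, vocab_size)  # one score row per beam
    flat = scores.view(num_beams * vocab_size)   # beams flattened together
    next_scores, next_tokens = torch.topk(flat, 2 * num_beams)

    for beam_token_id in next_tokens:
        beam_id = beam_token_id // vocab_size    # which beam the candidate extends
        token_id = beam_token_id % vocab_size    # which vocabulary entry it adds
        assert scores[beam_id, token_id] == flat[beam_token_id]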