Commit c0d9dd3b authored by Patrick von Platen

refactored code a bit and made more generic

parent d8e2b3c5
@@ -69,6 +69,7 @@ class PretrainedConfig(object):
         # Parameters for sequence generation
         self.max_length = kwargs.pop("max_length", 20)
+        self.min_length = kwargs.pop("min_length", 0)
         self.do_sample = kwargs.pop("do_sample", False)
         self.early_stopping = kwargs.pop("early_stopping", False)
         self.num_beams = kwargs.pop("num_beams", 1)
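For context, a minimal sketch of the `kwargs.pop` pattern this hunk extends: each generation parameter is popped with a default, so configs built without `min_length` still get 0. `TinyConfig` is a hypothetical stand-in for `PretrainedConfig`, for illustration only:

    class TinyConfig:
        def __init__(self, **kwargs):
            # Omitted keys silently fall back to their defaults.
            self.max_length = kwargs.pop("max_length", 20)
            self.min_length = kwargs.pop("min_length", 0)  # the parameter added in this commit
            self.do_sample = kwargs.pop("do_sample", False)

    config = TinyConfig(min_length=5)
    assert config.min_length == 5 and config.max_length == 20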
...
@@ -609,6 +609,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         self,
         input_ids=None,
         max_length=None,
+        min_length=None,
         do_sample=True,
         num_beams=None,
         temperature=None,
@@ -713,6 +714,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         )
         max_length = max_length if max_length is not None else self.config.max_length
+        min_length = min_length if min_length is not None else self.config.min_length
         do_sample = do_sample if do_sample is not None else self.config.do_sample
         num_beams = num_beams if num_beams is not None else self.config.num_beams
         temperature = temperature if temperature is not None else self.config.temperature
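For clarity, the resolution rule these lines apply, as a tiny sketch (names hypothetical): an explicit call-site argument wins, and `None` means "use the config default".

    def resolve(value, default):
        # Explicit call-site arguments win; None means "use the config default".
        return value if value is not None else default

    assert resolve(None, 0) == 0   # generate(min_length=None) -> config.min_length
    assert resolve(10, 0) == 10    # generate(min_length=10) overrides the config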
@@ -735,6 +737,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             eos_token_ids = [eos_token_ids]
         assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
+        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a non-negative integer."
         assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
         assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
         assert temperature > 0, "`temperature` should be strictly positive."
@@ -824,12 +827,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                # eos_token_id,
+                # eos_token_id,  # Why eos_token_id here? bos_token_id makes more sense no?
                 bos_token_id,
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
-            cur_len = 0
+            cur_len = 1
             self.model.decoder.generation_mode = True
         else:
             encoder_inputs = None
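The `cur_len = 1` change matters: the decoder prompt is seeded with a single BOS column, so the running length already counts one token. A minimal sketch of that seeding, with an assumed `bos_token_id`:

    import torch

    batch_size, num_beams = 2, 4
    bos_token_id = 0  # assumed id, purely for illustration

    # One BOS token per (example, beam) pair: shape (batch_size * num_beams, 1).
    input_ids = torch.full((batch_size * num_beams, 1), bos_token_id, dtype=torch.long)

    # The sequence already holds one token, so generation starts at cur_len = 1.
    cur_len = input_ids.shape[-1]
    assert cur_len == 1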
@@ -840,6 +843,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 input_ids,
                 cur_len,
                 max_length,
+                min_length,
                 do_sample,
                 temperature,
                 top_k,
@@ -859,6 +863,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 input_ids,
                 cur_len,
                 max_length,
+                min_length,
                 do_sample,
                 temperature,
                 top_k,
@@ -877,6 +882,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             input_ids,
             cur_len,
             max_length,
+            min_length,
             do_sample,
             temperature,
             top_k,
@@ -911,6 +917,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             if repetition_penalty != 1.0:
                 self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
+            if eos_token_ids is not None and cur_len < min_length:
+                for eos_token_id in eos_token_ids:
+                    # set the EOS token probability to (effectively) zero, as is done for attention masks
+                    next_token_logits[:, eos_token_id] = -10000.0
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
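This is the heart of the new `min_length` support: until the prefix reaches `min_length`, every EOS logit is pushed to -10000.0 so softmax assigns it a vanishing probability, the same large-negative trick used for attention masks (and, unlike `-inf`, it stays finite in fp16). A self-contained sketch, with assumed token ids and sizes:

    import torch

    vocab_size, min_length, cur_len = 10, 5, 3
    eos_token_ids = [2]  # assumed EOS id, for illustration
    next_token_logits = torch.zeros(1, vocab_size)

    if cur_len < min_length:
        for eos_token_id in eos_token_ids:
            next_token_logits[:, eos_token_id] = -10000.0  # large negative, fp16-safe

    probs = torch.softmax(next_token_logits, dim=-1)
    assert probs[0, 2].item() < 1e-4  # EOS is effectively unreachable before min_length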
@@ -965,6 +975,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 input_ids,
                 cur_len,
                 max_length,
+                min_length,
                 do_sample,
                 temperature,
                 top_k,
@@ -1022,6 +1033,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     next_token_logits, batch_size, num_beams, input_ids, repetition_penalty
                 )
+            if eos_token_ids is not None and cur_len < min_length:
+                for eos_token_id in eos_token_ids:
+                    # set the EOS token probability to (effectively) zero, as is done for attention masks
+                    next_token_logits[:, eos_token_id] = -10000.0
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
                 if temperature != 1.0:
@@ -1056,18 +1071,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
             if is_encoder_decoder:  # TODO(PVP) to be refactored later
-                import math
-                # scores[scores != scores] = -math.inf  # block nans
-                # scores[:, pad_token_id] = -math.inf
+                # scores[scores != scores] = -math.inf  # block nans => seems very hacky here
+                # scores[:, pad_token_id] = -math.inf => seems very hacky here
                 # TODO(SS): fairseq also takes out <unk> every step, and has unk at slot 3
-                # if cur_len == 0:  # Force BOS to be chosen
-                #     scores[:, self.config.bos_token_id + 1 :] = -math.inf  # TODO(PVP) should not use bos_token_id here
-                # elif cur_len < min_len:  # Prevent EOS from being chosen TODO: for the moment don't think about min_len
-                #     scores[:, eos_token_ids[0]] = -math.inf
-                # elif cur_len == max_length:  # FORCE EOS to be chosen
-                if cur_len == max_length:  # FORCE EOS to be chosen
-                    scores[:, :eos_token_ids[0]] = -math.inf
-                    scores[:, eos_token_ids[0] + 1 :] = -math.inf
+                # if cur_len == 0:  # Force BOS to be chosen => also very hacky ... seems also to work without this line
+                #     scores[:, self.config.bos_token_id + 1 :] = -math.inf
+                if cur_len == max_length - 1:  # FORCE EOS to be chosen
+                    all_but_eos_mask = torch.tensor([x for x in range(vocab_size) if x not in eos_token_ids], dtype=torch.long, device=next(self.parameters()).device)
+                    scores[:, all_but_eos_mask] = -10000.0
             assert scores.size() == (batch_size * num_beams, vocab_size)
             # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
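The rewritten force-EOS branch replaces the old two-slice `-math.inf` trick (which only handled a single EOS id) with an index mask over every non-EOS token, and it fires at `max_length - 1` since `cur_len` now starts at 1. A standalone sketch of the masking, with assumed ids:

    import torch

    vocab_size = 10
    eos_token_ids = [2, 3]  # assumed EOS ids, for illustration
    scores = torch.zeros(1, vocab_size)

    # Every token id that is *not* an EOS token ...
    all_but_eos_mask = torch.tensor(
        [x for x in range(vocab_size) if x not in eos_token_ids], dtype=torch.long
    )
    # ... gets suppressed, so the final step can only choose an EOS token.
    scores[:, all_but_eos_mask] = -10000.0

    assert scores.argmax(dim=-1).item() in eos_token_ids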
@@ -1194,7 +1205,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         # shorter batches are filled with pad_token
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`pad_token_id` has to be defined"
-            sent_max_len = min(sent_lengths.max().item() + 1, max_length + 1)
+            sent_max_len = min(sent_lengths.max().item() + 1, max_length)
             decoded = input_ids.new(output_batch_size, sent_max_len).fill_(pad_token_id)
             # fill with hypothesis and eos_token_id if necessary
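How the decoded batch is assembled when hypotheses end at different lengths, per this hunk: the `+ 1` reserves a slot for the EOS appended below, and the cap is now `max_length` rather than `max_length + 1`, so the output never exceeds the requested budget. A sketch with made-up lengths and token ids:

    import torch

    pad_token_id, eos_token_id, max_length = 0, 2, 8
    hypotheses = [torch.tensor([5, 6, 7]), torch.tensor([5, 6, 7, 8, 9])]
    sent_lengths = torch.tensor([len(h) for h in hypotheses])

    # Reserve one extra slot for the EOS token, but never exceed max_length.
    sent_max_len = min(sent_lengths.max().item() + 1, max_length)
    decoded = torch.full((len(hypotheses), sent_max_len), pad_token_id, dtype=torch.long)

    for i, hypo in enumerate(hypotheses):
        decoded[i, : len(hypo)] = hypo
        if len(hypo) < sent_max_len:
            decoded[i, len(hypo)] = eos_token_id  # close sequences that didn't hit max_length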
...
@@ -442,7 +442,7 @@ class BartModelIntegrationTest(unittest.TestCase):
         tokens = tok.encode(text, return_tensors="pt").to(torch_device)
         extra_len = 20
         gen_tokens_1 = hf.generate_1(tokens, num_beams=4, max_length=extra_len,)  # repetition_penalty=10.,
-        gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len, do_sample=False)  # repetition_penalty=10.,
+        gen_tokens = hf.generate(tokens, num_beams=4, max_length=extra_len + 2, do_sample=False)  # repetition_penalty=10.,
         print("1: {}".format(gen_tokens_1))
         print("2: {}".format(gen_tokens))
         ipdb.set_trace()
...
@@ -621,7 +621,7 @@ class ModelTesterMixin:
         with torch.no_grad():
             model(**inputs_dict)

-    def _A_test_lm_head_model_random_generate(self):
+    def test_lm_head_model_random_generate(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         input_ids = inputs_dict.get(
...