Commit 333affcb authored by Patrick von Platen

add current changes

parent 42121699
@@ -614,6 +614,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         max_length=None,
         min_length=None,
         do_sample=True,
+        early_stopping=False,
         num_beams=None,
         temperature=None,
         top_k=None,
@@ -720,7 +721,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         max_length = max_length if max_length is not None else self.config.max_length
         min_length = min_length if min_length is not None else self.config.min_length
         do_sample = do_sample if do_sample is not None else self.config.do_sample
+        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
         num_beams = num_beams if num_beams is not None else self.config.num_beams
         temperature = temperature if temperature is not None else self.config.temperature
         top_k = top_k if top_k is not None else self.config.top_k
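A note on the pattern in this hunk: every sampling knob resolves to the explicit keyword argument when one is given, and otherwise falls back to the value stored on the model config. A minimal, self-contained sketch of that resolution (the config class below is a stand-in for illustration, not the transformers config class):

```python
from dataclasses import dataclass


@dataclass
class ConfigSketch:
    # stand-in for self.config; defaults are illustrative only
    early_stopping: bool = False
    num_beams: int = 1


def resolve(value, config_value):
    # mirrors `x = x if x is not None else self.config.x` above
    return value if value is not None else config_value


config = ConfigSketch()
print(resolve(None, config.early_stopping))  # falls back to config -> False
print(resolve(4, config.num_beams))          # explicit argument wins -> 4
```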
@@ -747,6 +748,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
         assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
         assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
+        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
         assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
         assert temperature > 0, "`temperature` should be strictly positive."
         assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
@@ -841,8 +843,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                # eos_token_id,
-                bos_token_id,
+                eos_token_id,
+                # bos_token_id,
                 # eos_token_id, # Why eos_token_id here? bos_token_id seems to work as well ... to see if it works as well with hard summarization case
                 dtype=torch.long,
                 device=next(self.parameters()).device,
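What the changed lines build, in isolation: a `(batch * beams, 1)` tensor in which every beam starts from a single start token; the commit switches that start token from `bos_token_id` to `eos_token_id` for encoder-decoder models. A toy sketch with made-up sizes and a placeholder token id:

```python
import torch

effective_batch_size, num_beams = 2, 3
eos_token_id = 2  # placeholder id, not taken from any real tokenizer

# one row per (batch item, beam) pair, each holding only the start token
input_ids = torch.full(
    (effective_batch_size * num_beams, 1),
    eos_token_id,
    dtype=torch.long,
)
print(input_ids.shape)  # torch.Size([6, 1])
print(input_ids[0])     # tensor([2])
```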
@@ -860,6 +862,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 max_length,
                 min_length,
                 do_sample,
+                early_stopping,
                 temperature,
                 top_k,
                 top_p,
@@ -1012,6 +1015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         max_length,
         min_length,
         do_sample,
+        early_stopping,
         temperature,
         top_k,
         top_p,
@@ -1033,7 +1037,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         # generated hypotheses
         generated_hyps = [
-            BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
+            BeamHypotheses(num_beams, max_length - 1, length_penalty, early_stopping=early_stopping) for _ in range(batch_size)
         ]
         # scores for each sentence in the beam
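For readers of this hunk: `early_stopping` now reaches the `BeamHypotheses` containers instead of being hard-coded to `False`. A toy re-implementation (not the transformers class) that shows the flag's semantics: with `early_stopping=True` a container reports done as soon as it holds `num_beams` finished hypotheses, while with `False` it also requires that the best still-open beam can no longer beat the worst finished score under the length penalty:

```python
class ToyBeamHypotheses:
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        self.num_beams = num_beams
        self.max_length = max_length
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.beams = []          # (score, token_list) pairs
        self.worst_score = 1e9

    def add(self, tokens, sum_logprobs):
        # length-penalized score, as in beam search
        score = sum_logprobs / len(tokens) ** self.length_penalty
        if len(self.beams) < self.num_beams or score > self.worst_score:
            self.beams.append((score, tokens))
            if len(self.beams) > self.num_beams:
                # evict the worst finished hypothesis
                self.beams.remove(min(self.beams, key=lambda b: b[0]))
            self.worst_score = min(s for s, _ in self.beams)

    def is_done(self, best_sum_logprobs):
        if len(self.beams) < self.num_beams:
            return False
        if self.early_stopping:
            return True
        # best still-open hypothesis, scored as if it ran to max_length
        return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty


hyps = ToyBeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=True)
hyps.add([2, 5, 7], sum_logprobs=-1.2)
hyps.add([2, 9, 4, 3], sum_logprobs=-1.0)
print(hyps.is_done(best_sum_logprobs=-0.5))  # True: two finished beams + early stopping
```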
@@ -1080,11 +1084,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             # force eos to be chosen at end of generation for encoder-decoder models
             # TODO (PVP): both these things are very hacky - see whether it might be possible to solve this differently
             if self.config.is_encoder_decoder:
-                # self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
                 if cur_len == 1:
                     self._force_token_ids_generation(next_token_logits, bos_token_id)
                 if cur_len == max_length - 1:
                     self._force_token_ids_generation(next_token_logits, eos_token_ids)
+                # self.prepare_logits_for_softmax(next_token_logits, cur_len, max_length)
             if do_sample:
                 # Temperature (higher temperature => more likely to sample low probability tokens)
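`_force_token_ids_generation` itself is not shown in this diff; based on how it is called above, a plausible sketch (an assumption, not the actual implementation) is a logit mask that leaves only the forced token(s) sampleable:

```python
import torch


def force_token_ids_generation(logits: torch.Tensor, token_ids) -> None:
    """Push every logit except `token_ids` to -inf, in place (sketch only)."""
    if isinstance(token_ids, int):
        token_ids = [token_ids]
    mask = torch.ones_like(logits, dtype=torch.bool)
    mask[:, token_ids] = False
    logits[mask] = -float("inf")


logits = torch.randn(2, 10)          # toy (batch, vocab) logits
force_token_ids_generation(logits, 2)
print(logits.argmax(dim=-1))         # tensor([2, 2]): token 2 is now forced
```

After masking, both greedy argmax and sampling can only pick the forced token, which is how the calls above pin the first generated token to `bos_token_id` and the last one to an EOS token.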