Unverified Commit 741d48f5 authored by Ashwin Geet D'Sa, committed by GitHub

Remove max length beam scorer (#11378)



* removed max_len

* removed max_length from BeamSearchScorer

* correct max length

* finish

* del vim

* finish & add test
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent bc2571e6
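In a nutshell: `max_length` moves off the `BeamSearchScorer` and is instead taken from the stopping criteria and passed explicitly to `finalize(...)`. A minimal sketch of the post-change construction (the CPU device and the sizes are illustrative, not taken from this diff):

```python
import torch
from transformers import BeamSearchScorer

# no `max_length` in the constructor anymore; passing one only triggers a
# deprecation warning and is otherwise ignored
beam_scorer = BeamSearchScorer(
    batch_size=1,
    num_beams=3,
    device=torch.device("cpu"),
)
```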
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings
 from abc import ABC, abstractmethod
 from collections import UserDict
 from typing import Optional, Tuple
@@ -110,6 +111,7 @@ class BeamScorer(ABC):
         next_scores: torch.FloatTensor,
         next_tokens: torch.LongTensor,
         next_indices: torch.LongTensor,
+        max_length: int,
         **kwargs
     ) -> torch.LongTensor:
         raise NotImplementedError("This is an abstract method.")
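For custom scorers, the abstract contract above now carries `max_length` at finalize time. A hedged skeleton of a conforming subclass (bodies elided; this assumes `BeamScorer` is importable from the top-level package like `BeamSearchScorer` is):

```python
from transformers import BeamScorer

class MyBeamScorer(BeamScorer):
    # `process` keeps its signature; `finalize` now receives `max_length`
    # as an explicit argument instead of reading stored scorer state
    def process(self, input_ids, next_scores, next_tokens, next_indices, **kwargs):
        ...

    def finalize(
        self, input_ids, final_beam_scores, final_beam_tokens, final_beam_indices, max_length, **kwargs
    ):
        ...
```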
@@ -152,15 +154,14 @@ class BeamSearchScorer(BeamScorer):
     def __init__(
         self,
         batch_size: int,
-        max_length: int,
         num_beams: int,
         device: torch.device,
         length_penalty: Optional[float] = 1.0,
         do_early_stopping: Optional[bool] = False,
         num_beam_hyps_to_keep: Optional[int] = 1,
         num_beam_groups: Optional[int] = 1,
+        **kwargs,
     ):
-        self.max_length = max_length
         self.num_beams = num_beams
         self.device = device
         self.length_penalty = length_penalty
@@ -173,7 +174,6 @@ class BeamSearchScorer(BeamScorer):
         self._beam_hyps = [
             BeamHypotheses(
                 num_beams=self.num_beams,
-                max_length=self.max_length,
                 length_penalty=self.length_penalty,
                 early_stopping=self.do_early_stopping,
             )
@@ -192,6 +192,13 @@ class BeamSearchScorer(BeamScorer):
                 f"has to be divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
             )
 
+        if "max_length" in kwargs:
+            warnings.warn(
+                "Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
+                "`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`, "
+                "or `group_beam_search(...)`."
+            )
+
     @property
     def is_done(self) -> bool:
         return self._done.all()
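The deprecation path can be exercised directly; a small sanity check along these lines (sizes and device are arbitrary):

```python
import warnings

import torch
from transformers import BeamSearchScorer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    BeamSearchScorer(batch_size=1, num_beams=2, device=torch.device("cpu"), max_length=10)

# the constructor warns but does not fail, and the kwarg has no effect
assert any("max_length" in str(w.message) for w in caught)
```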
@@ -279,6 +286,7 @@ class BeamSearchScorer(BeamScorer):
         final_beam_scores: torch.FloatTensor,
         final_beam_tokens: torch.LongTensor,
         final_beam_indices: torch.LongTensor,
+        max_length: int,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
     ) -> Tuple[torch.LongTensor]:
@@ -316,7 +324,7 @@ class BeamSearchScorer(BeamScorer):
                 best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
 
         # prepare for adding eos
-        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
+        sent_max_len = min(sent_lengths.max().item() + 1, max_length)
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
@@ -326,7 +334,7 @@ class BeamSearchScorer(BeamScorer):
         # fill with hypotheses and eos_token_id if the latter fits in
         for i, hypo in enumerate(best):
             decoded[i, : sent_lengths[i]] = hypo
-            if sent_lengths[i] < self.max_length:
+            if sent_lengths[i] < max_length:
                 decoded[i, sent_lengths[i]] = eos_token_id
 
         return UserDict(
             {
@@ -337,11 +345,10 @@ class BeamSearchScorer(BeamScorer):
 class BeamHypotheses:
-    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
+    def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
         """
         Initialize n-best list of hypotheses.
         """
-        self.max_length = max_length - 1  # ignoring bos_token
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
         self.num_beams = num_beams
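With the cap removed, `BeamHypotheses` only keeps ranking state (`num_beams`, `length_penalty`, `early_stopping`). A sketch of the reduced constructor; the internal module path here is an assumption for this version of the library:

```python
from transformers.generation_beam_search import BeamHypotheses

# no length cap is stored here anymore; lengths are bounded by the caller
hyp = BeamHypotheses(num_beams=3, length_penalty=1.0, early_stopping=False)
assert len(hyp) == 0  # empty n-best list until hypotheses are added
```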
...
@@ -1027,7 +1027,6 @@ class GenerationMixin:
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1063,7 +1062,6 @@ class GenerationMixin:
                 raise ValueError("`max_length` needs to be a stopping_criteria for now.")
 
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1700,7 +1698,6 @@ class GenerationMixin:
         >>> # instantiate beam scorer
         >>> beam_scorer = BeamSearchScorer(
         ...     batch_size=1,
-        ...     max_length=model.config.max_length,
         ...     num_beams=num_beams,
         ...     device=model.device,
         ... )
@@ -1756,7 +1753,7 @@ class GenerationMixin:
         assert (
             num_beams * batch_size == batch_beam_size
-        ), "Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
+        ), f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
 
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores[:, 1:] = -1e9
@@ -1792,10 +1789,7 @@ class GenerationMixin:
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=None
-            )
+            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
 
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
 
             next_token_scores = logits_processor(input_ids, next_token_scores)
@@ -1861,7 +1855,13 @@ class GenerationMixin:
                 this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=stopping_criteria.max_length,
         )
 
         if return_dict_in_generate:
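At the call site, the length bound now comes from the stopping criteria rather than scorer state. A hedged sketch of the caller-side setup (`20` is an arbitrary bound):

```python
from transformers import MaxLengthCriteria, StoppingCriteriaList

stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])

# `beam_search(...)` reads the bound from here and forwards it, i.e.
# beam_scorer.finalize(..., max_length=stopping_criteria.max_length)
assert stopping_criteria.max_length == 20
```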
@@ -2086,10 +2086,7 @@ class GenerationMixin:
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=None
-            )
+            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
 
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
 
             next_token_scores = logits_processor(input_ids, next_token_scores)
@@ -2160,7 +2157,13 @@ class GenerationMixin:
                 this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=stopping_criteria.max_length,
         )
 
         if return_dict_in_generate:
@@ -2411,10 +2414,7 @@ class GenerationMixin:
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
-            next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=None
-            )
+            next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
 
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)
 
             vocab_size = next_token_scores.shape[-1]
@@ -2497,7 +2497,13 @@ class GenerationMixin:
                 this_peer_finished = True
 
         sequence_outputs = beam_scorer.finalize(
-            input_ids, beam_scores, next_tokens, next_indices, pad_token_id=pad_token_id, eos_token_id=eos_token_id
+            input_ids,
+            beam_scores,
+            next_tokens,
+            next_indices,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            max_length=stopping_criteria.max_length,
         )
 
         if return_dict_in_generate:
...
@@ -1335,7 +1335,7 @@ class MarianMTModel(MarianPreTrainedModel):
     def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
-    def adjust_logits_during_generation(self, logits, cur_len, max_length):
+    def adjust_logits_during_generation(self, logits, cur_len):
         logits[:, self.config.pad_token_id] = float("-inf")  # never predict pad token.
         return logits
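The Marian hook now takes only `cur_len`. A standalone re-implementation to show the effect; the pad id and vocab size are made-up values, not Marian's real config:

```python
import torch

PAD_TOKEN_ID = 3  # hypothetical pad id for illustration

def adjust_logits_during_generation(logits: torch.Tensor, cur_len: int) -> torch.Tensor:
    # never predict the pad token, regardless of how long the sequence is
    logits[:, PAD_TOKEN_ID] = float("-inf")
    return logits

logits = torch.zeros(2, 10)
adjusted = adjust_logits_during_generation(logits, cur_len=1)
assert torch.isinf(adjusted[:, PAD_TOKEN_ID]).all()
```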
...
@@ -1543,7 +1543,6 @@ class RagTokenForGeneration(RagPreTrainedModel):
             raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
 
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=self.device,
             length_penalty=length_penalty,
...
@@ -59,7 +59,6 @@ class BeamSearchTester:
     def prepare_beam_scorer(self, **kwargs):
         return BeamSearchScorer(
             batch_size=kwargs.get("batch_size", self.batch_size),
-            max_length=kwargs.get("max_length", self.max_length),
             num_beams=kwargs.get("num_beams", self.num_beams),
             device=torch_device,
             length_penalty=kwargs.get("length_penalty", self.length_penalty),
@@ -170,9 +169,7 @@ class BeamSearchTester:
     def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
         # max_length should be only one more than current input_ids to check that eos is correctly appended
         max_length = self.sequence_length + 1
-        beam_scorer = self.prepare_beam_scorer(
-            num_beam_hyps_to_keep=1, max_length=max_length, length_penalty=1.0, do_early_stopping=False
-        )
+        beam_scorer = self.prepare_beam_scorer(num_beam_hyps_to_keep=1, length_penalty=1.0, do_early_stopping=False)
 
         # update beams and append to input_ids
         tokens = next_tokens.clone()
@@ -197,6 +194,7 @@ class BeamSearchTester:
             output_indices,
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
+            max_length=max_length,
         )
 
         sequences = sequence_output["sequences"]
@@ -225,6 +223,7 @@ class BeamSearchTester:
             output_indices,
             pad_token_id=self.pad_token_id,
             eos_token_id=self.eos_token_id,
+            max_length=max_length,
         )
 
         sequences = sequence_output["sequences"]
         sequence_scores = sequence_output["sequence_scores"]
...
@@ -148,7 +148,6 @@ class GenerationTesterMixin:
         }
 
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=beam_kwargs["num_beams"],
             device=torch_device,
             length_penalty=beam_kwargs["length_penalty"],
@@ -169,7 +168,6 @@ class GenerationTesterMixin:
         }
 
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=beam_kwargs["num_beams"],
             device=torch_device,
             length_penalty=beam_kwargs["length_penalty"],
@@ -1411,7 +1409,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
         )
@@ -1442,7 +1439,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,
@@ -1502,7 +1498,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # Beam
         beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
         )
@@ -1520,7 +1515,6 @@ class GenerationIntegrationTests(unittest.TestCase):
         # Grouped beam search
         diverse_beam_scorer = BeamSearchScorer(
             batch_size=batch_size,
-            max_length=max_length,
             num_beams=num_beams,
             device=torch_device,
             num_beam_hyps_to_keep=num_return_sequences,
@@ -1535,3 +1529,51 @@ class GenerationIntegrationTests(unittest.TestCase):
             max_length=max_length,
             **model_kwargs,
         )
+
+    def test_beam_search_warning_if_max_length_is_passed(self):
+        article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
+        bart_tokenizer = BartTokenizer.from_pretrained("sshleifer/bart-tiny-random")
+        bart_model = BartForConditionalGeneration.from_pretrained("sshleifer/bart-tiny-random").to(torch_device)
+
+        batch_size = 1
+        num_beams = 3
+
+        input_ids = bart_tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
+        input_ids = input_ids.expand(num_beams, -1)
+        model_kwargs = bart_model._prepare_encoder_decoder_kwargs_for_generation(input_ids, {})
+
+        stopping_criteria_max_length = 18
+        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=stopping_criteria_max_length)])
+
+        with self.assertWarns(UserWarning):
+            beam_scorer = BeamSearchScorer(
+                batch_size=batch_size,
+                num_beams=num_beams,
+                device=torch_device,
+                max_length=10,
+            )
+
+        generated_ids = bart_model.beam_search(
+            input_ids,
+            num_beams=num_beams,
+            stopping_criteria=stopping_criteria,
+            beam_scorer=beam_scorer,
+            **model_kwargs,
+        )
+
+        beam_scorer_no_max_len = BeamSearchScorer(
+            batch_size=batch_size,
+            num_beams=num_beams,
+            device=torch_device,
+        )
+
+        generated_ids_no_max_len = bart_model.beam_search(
+            input_ids,
+            num_beams=num_beams,
+            stopping_criteria=stopping_criteria,
+            beam_scorer=beam_scorer_no_max_len,
+            **model_kwargs,
+        )
+
+        # BeamSearchScorer max_length should not influence "real" max_length
+        self.assertEqual(generated_ids.tolist(), generated_ids_no_max_len.tolist())