Unverified Commit 3814e167 authored by Thomas Wolf, committed by GitHub

Merge pull request #3225 from patrickvonplaten/finalize_merge_bart_generate_into_default_generate

Complete merge Seq-2-Seq generation into default generation
parents 2bd79e23 4f75d380
@@ -20,6 +20,10 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
     fout = Path(out_file).open("w")
     model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
     tokenizer = BartTokenizer.from_pretrained("bart-large")
+    max_length = 140
+    min_length = 55
     for batch in tqdm(list(chunks(lns, batch_size))):
         dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
         summaries = model.generate(
@@ -27,11 +31,12 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
             attention_mask=dct["attention_mask"].to(device),
             num_beams=4,
             length_penalty=2.0,
-            max_length=142,  # +2 from original because we start at step=1 and stop before max_length
-            min_length=56,  # +1 from original because we start at step=1
+            max_length=max_length + 2,  # +2 from original because we start at step=1 and stop before max_length
+            min_length=min_length + 1,  # +1 from original because we start at step=1
             no_repeat_ngram_size=3,
             early_stopping=True,
             do_sample=False,
+            decoder_start_token_id=model.config.eos_token_ids[0],
         )
         dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
         for hypothesis in dec:
...
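For reference, the updated summarization call can be exercised on its own roughly as below. This is a minimal sketch based on the diff above, written against the library as it stood at the time of this PR: the un-prefixed `bart-large-cnn` / `bart-large` checkpoint names, the `output_past` flag, `pad_to_max_length`, and the `eos_token_ids` config attribute all come from the code being changed here and may differ in later releases.

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Model and tokenizer, loaded as in the summarization script above.
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large")

article = "The Palestinian Authority officially became the 123rd member of the International Criminal Court."
dct = tokenizer.batch_encode_plus([article], max_length=1024, return_tensors="pt", pad_to_max_length=True)

max_length = 140
min_length = 55
summaries = model.generate(
    input_ids=dct["input_ids"].to(device),
    attention_mask=dct["attention_mask"].to(device),
    num_beams=4,
    length_penalty=2.0,
    max_length=max_length + 2,  # +2 because generation starts at step=1 and stops before max_length
    min_length=min_length + 1,  # +1 because generation starts at step=1
    no_repeat_ngram_size=3,
    early_stopping=True,
    do_sample=False,
    decoder_start_token_id=model.config.eos_token_ids[0],  # new argument introduced by this PR
)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries])
```

Passing `model.config.eos_token_ids[0]` explicitly matches the fairseq BART setup, where the decoder is primed with the end-of-sequence token rather than the `bos_token_id` fallback that `generate()` would otherwise use.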
@@ -628,6 +628,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         no_repeat_ngram_size=None,
         num_return_sequences=None,
         attention_mask=None,
+        decoder_start_token_id=None,
     ):
         r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
         and beam-search.
@@ -739,6 +740,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
+        # TODO: think about how to make this cleaner
+        decoder_start_token_id = (
+            decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id
+        )
         if input_ids is not None:
             batch_size = input_ids.shape[0]  # overriden by the input batch_size
@@ -765,6 +770,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         assert (eos_token_ids is None) or (
             isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
         ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
+        assert (
+            decoder_start_token_id is not None or self.config.is_encoder_decoder is False
+        ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
         assert length_penalty > 0, "`length_penalty` should be strictly positive."
         assert (
             isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
@@ -845,7 +853,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                bos_token_id,
+                decoder_start_token_id,  # TODO: see whether this is the best result
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
@@ -1082,7 +1090,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
             if self.config.is_encoder_decoder and do_sample is False:
-                # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+                # TODO: maybe give better naming
                 scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
             # set eos token prob to zero if min_length is not reached
@@ -1276,7 +1284,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
         if self.config.is_encoder_decoder:
-            # do not return first <EOS> token
             return decoded[:, 1:]
         return decoded
...
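The heart of the change to `generate()` is how the first decoder input is chosen for encoder-decoder models: an explicit `decoder_start_token_id` argument wins, the method otherwise falls back to `config.bos_token_id`, and an encoder-decoder model must end up with a defined start token. The sketch below isolates that logic using hypothetical helper names (`resolve_decoder_start_token_id` and `build_initial_decoder_input` are not part of the library); it mirrors the fallback, the assertion, and the `torch.full` call from the diff above.

```python
import torch


def resolve_decoder_start_token_id(decoder_start_token_id, config):
    # Explicit argument takes precedence; otherwise fall back to the configured BOS token.
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else config.bos_token_id
    )
    # Encoder-decoder models cannot start decoding without a start token.
    assert (
        decoder_start_token_id is not None or config.is_encoder_decoder is False
    ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
    return decoder_start_token_id


def build_initial_decoder_input(decoder_start_token_id, effective_batch_size, num_beams, device):
    # One start token per (batch * beam) row, mirroring the torch.full(...) call above;
    # the encoder keeps the original input_ids while the decoder starts from this single column.
    return torch.full(
        (effective_batch_size * num_beams, 1),
        decoder_start_token_id,
        dtype=torch.long,
        device=device,
    )
```

Because the generated sequence now begins with this priming token, the encoder-decoder branch keeps returning `decoded[:, 1:]`, i.e. the start token is still stripped from the returned output.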
@@ -61,7 +61,7 @@ class ModelTester:
         self.hidden_dropout_prob = 0.1
         self.attention_probs_dropout_prob = 0.1
         self.max_position_embeddings = 20
-        self.eos_token_id = 2
+        self.eos_token_ids = [2]
         self.pad_token_id = 1
         self.bos_token_id = 0
         torch.manual_seed(0)
@@ -82,7 +82,7 @@ class ModelTester:
             dropout=self.hidden_dropout_prob,
             attention_dropout=self.attention_probs_dropout_prob,
             max_position_embeddings=self.max_position_embeddings,
-            eos_token_ids=[self.eos_token_id],
+            eos_token_ids=[2],
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
@@ -438,7 +438,11 @@ class BartModelIntegrationTest(unittest.TestCase):
         tokens = tok.encode(text, return_tensors="pt").to(torch_device)
         extra_len = 20
         gen_tokens = hf.generate(
-            tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
+            tokens,
+            num_beams=4,
+            max_length=extra_len + 2,
+            do_sample=False,
+            decoder_start_token_id=hf.config.eos_token_ids[0],
         )  # repetition_penalty=10.,
         expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
         generated = [tok.decode(g,) for g in gen_tokens]
@@ -483,6 +487,7 @@ class BartModelIntegrationTest(unittest.TestCase):
             no_repeat_ngram_size=3,
             do_sample=False,
             early_stopping=True,
+            decoder_start_token_id=hf.config.eos_token_ids[0],
         )
         decoded = [
...