"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "f6dc2f67783082b433dfa99d4b0a8992ba64be9d"
Commit c2ee3840 authored by Patrick von Platen
Browse files

update file to new starting token logic

parent 6a82f774
...@@ -20,6 +20,10 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE): ...@@ -20,6 +20,10 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
fout = Path(out_file).open("w") fout = Path(out_file).open("w")
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device) model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large") tokenizer = BartTokenizer.from_pretrained("bart-large")
max_length = 140
min_length = 55
for batch in tqdm(list(chunks(lns, batch_size))): for batch in tqdm(list(chunks(lns, batch_size))):
dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True) dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
summaries = model.generate( summaries = model.generate(
...@@ -27,11 +31,12 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE): ...@@ -27,11 +31,12 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
attention_mask=dct["attention_mask"].to(device), attention_mask=dct["attention_mask"].to(device),
num_beams=4, num_beams=4,
length_penalty=2.0, length_penalty=2.0,
max_length=142, # +2 from original because we start at step=1 and stop before max_length max_length=max_length + 2, # +2 from original because we start at step=1 and stop before max_length
min_length=56, # +1 from original because we start at step=1 min_length=min_length + 1, # +1 from original because we start at step=1
no_repeat_ngram_size=3, no_repeat_ngram_size=3,
early_stopping=True, early_stopping=True,
do_sample=False, do_sample=False,
decoder_start_token_id=model.config.eos_token_ids[0]
) )
dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
for hypothesis in dec: for hypothesis in dec:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment