Unverified Commit d0486c8b authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[cleanup] T5 test, warnings (#5761)

parent ec0a945c
......@@ -46,9 +46,7 @@ def generate_summaries_or_translations(
for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name:
batch = [model.config.prefix + text for text in batch]
batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to(
device
)
batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
......
......@@ -2,6 +2,7 @@ import itertools
import json
import os
import pickle
from logging import getLogger
from pathlib import Path
from typing import Callable, Dict, Iterable, List
......@@ -181,11 +182,18 @@ class SortishSampler(Sampler):
return iter(sort_idx)
logger = getLogger(__name__)
def use_task_specific_params(model, task):
# update config with summarization specific params
"""Update config with summarization specific params."""
task_specific_params = model.config.task_specific_params
if task_specific_params is not None:
model.config.update(task_specific_params.get(task, {}))
pars = task_specific_params.get(task, {})
logger.info(f"using task specific params for {task}: {pars}")
model.config.update(pars)
def pickle_load(path):
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment