Unverified commit d0486c8b, authored by Sam Shleifer, committed by GitHub
Browse files

[cleanup] T5 test, warnings (#5761)

parent ec0a945c
...@@ -46,9 +46,7 @@ def generate_summaries_or_translations( ...@@ -46,9 +46,7 @@ def generate_summaries_or_translations(
for batch in tqdm(list(chunks(examples, batch_size))): for batch in tqdm(list(chunks(examples, batch_size))):
if "t5" in model_name: if "t5" in model_name:
batch = [model.config.prefix + text for text in batch] batch = [model.config.prefix + text for text in batch]
batch = tokenizer(batch, max_length=1024, return_tensors="pt", truncation=True, padding="max_length").to( batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
device
)
input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id) input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs) summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False) dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
......
...@@ -2,6 +2,7 @@ import itertools ...@@ -2,6 +2,7 @@ import itertools
import json import json
import os import os
import pickle import pickle
from logging import getLogger
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, Iterable, List from typing import Callable, Dict, Iterable, List
...@@ -181,11 +182,18 @@ class SortishSampler(Sampler): ...@@ -181,11 +182,18 @@ class SortishSampler(Sampler):
return iter(sort_idx) return iter(sort_idx)
logger = getLogger(__name__)
def use_task_specific_params(model, task): def use_task_specific_params(model, task):
# update config with summarization specific params """Update config with summarization specific params."""
task_specific_params = model.config.task_specific_params task_specific_params = model.config.task_specific_params
if task_specific_params is not None: if task_specific_params is not None:
model.config.update(task_specific_params.get(task, {})) pars = task_specific_params.get(task, {})
logger.info(f"using task specific params for {task}: {pars}")
model.config.update(pars)
def pickle_load(path): def pickle_load(path):
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment