Unverified Commit 2db1e2f4 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[cleanup] remove redundant code in SummarizationDataset (#5119)

parent 5f721ad6
......@@ -13,8 +13,6 @@ from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from transformers import BartTokenizer
def encode_file(
tokenizer,
......@@ -85,7 +83,7 @@ class SummarizationDataset(Dataset):
prefix="",
):
super().__init__()
tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else ""
tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
self.source = encode_file(
tokenizer,
os.path.join(data_dir, type_path + ".source"),
......@@ -94,16 +92,10 @@ class SummarizationDataset(Dataset):
prefix=prefix,
tok_name=tok_name,
)
if type_path == "train":
tgt_path = os.path.join(data_dir, type_path + ".target")
else:
tgt_path = os.path.join(data_dir, type_path + ".target")
tgt_path = os.path.join(data_dir, type_path + ".target")
self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
)
self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
if n_obs is not None:
self.source = self.source[:n_obs]
self.target = self.target[:n_obs]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment