Commit 2db1e2f4 (unverified), authored by Sam Shleifer, committed by GitHub
Browse files

[cleanup] remove redundant code in SummarizationDataset (#5119)

parent 5f721ad6
...@@ -13,8 +13,6 @@ from torch import nn ...@@ -13,8 +13,6 @@ from torch import nn
from torch.utils.data import Dataset, Sampler from torch.utils.data import Dataset, Sampler
from tqdm import tqdm from tqdm import tqdm
from transformers import BartTokenizer
def encode_file( def encode_file(
tokenizer, tokenizer,
...@@ -85,7 +83,7 @@ class SummarizationDataset(Dataset): ...@@ -85,7 +83,7 @@ class SummarizationDataset(Dataset):
prefix="", prefix="",
): ):
super().__init__() super().__init__()
tok_name = "T5" if not isinstance(tokenizer, BartTokenizer) else "" tok_name = tokenizer.__class__.__name__.lower().rstrip("tokenizer")
self.source = encode_file( self.source = encode_file(
tokenizer, tokenizer,
os.path.join(data_dir, type_path + ".source"), os.path.join(data_dir, type_path + ".source"),
...@@ -94,16 +92,10 @@ class SummarizationDataset(Dataset): ...@@ -94,16 +92,10 @@ class SummarizationDataset(Dataset):
prefix=prefix, prefix=prefix,
tok_name=tok_name, tok_name=tok_name,
) )
if type_path == "train": tgt_path = os.path.join(data_dir, type_path + ".target")
tgt_path = os.path.join(data_dir, type_path + ".target")
else:
tgt_path = os.path.join(data_dir, type_path + ".target")
self.target = encode_file( self.target = encode_file(
tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name tokenizer, tgt_path, max_target_length, overwrite_cache=overwrite_cache, tok_name=tok_name
) )
self.source = encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length)
self.target = encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length)
if n_obs is not None: if n_obs is not None:
self.source = self.source[:n_obs] self.source = self.source[:n_obs]
self.target = self.target[:n_obs] self.target = self.target[:n_obs]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment