Commit 8cd56e30 authored by thomwolf

fix data processing in script

parent 578d23e0
@@ -58,12 +58,12 @@ class TextDataset(Dataset):
     [2] https://github.com/abisee/cnn-dailymail/
     """
 
-    def __init_(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
+    def __init__(self, tokenizer, prefix='train', data_dir="", block_size=512):
         assert os.path.isdir(data_dir)
 
         # Load features that have already been computed if present
         cached_features_file = os.path.join(
-            data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
+            data_dir, "cached_lm_{}_{}".format(block_size, prefix)
         )
         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
@@ -72,7 +72,7 @@ class TextDataset(Dataset):
                 return
 
         logger.info("Creating features from dataset at %s", data_dir)
         self.examples = []
         datasets = ["cnn", "dailymail"]
         for dataset in datasets:
             path_to_stories = os.path.join(data_dir, dataset, "stories")
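
Note: the loop above assumes the CNN/DailyMail archives have already been downloaded and unpacked under data_dir, following the abisee/cnn-dailymail link in the class docstring. A minimal sketch of the assumed layout (paths are hypothetical, not from the commit):

# Assumed on-disk layout:
#   <data_dir>/cnn/stories/*.story
#   <data_dir>/dailymail/stories/*.story
import os

data_dir = "/path/to/cnn_dailymail"  # hypothetical location
for dataset in ["cnn", "dailymail"]:
    print(os.path.join(data_dir, dataset, "stories"))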
@@ -91,21 +91,17 @@ class TextDataset(Dataset):
                 except IndexError: # skip ill-formed stories
                     continue
 
-                story = tokenizer_src.convert_tokens_to_ids(
-                    tokenizer_src.tokenize(story)
-                )
+                story = tokenizer.encode(story)
                 story_seq = _fit_to_block_size(story, block_size)
 
-                summary = tokenizer_tgt.convert_tokens_to_ids(
-                    tokenizer_tgt.tokenize(summary)
-                )
+                summary = tokenizer.encode(summary)
                 summary_seq = _fit_to_block_size(summary, block_size)
 
                 self.examples.append((story_seq, summary_seq))
 
         logger.info("Saving features into cache file %s", cached_features_file)
         with open(cached_features_file, "wb") as sink:
-            pickle.dump(self.examples, sink, protocole=pickle.HIGHEST_PROTOCOL)
+            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)
 
     def __len__(self):
         return len(self.examples)
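
The switch to a single tokenizer relies on tokenizer.encode covering both steps of the removed code. A minimal sketch of the assumed equivalence (whether special tokens such as [CLS]/[SEP] are inserted by default depends on the transformers version and flags):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
text = "A toy sentence, not from the dataset."

# What the old code did explicitly, here with the single shared tokenizer:
ids_manual = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
# What the new code does; may additionally add special tokens depending on version/flags:
ids_encode = tokenizer.encode(text)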
@@ -169,11 +165,11 @@ def _fit_to_block_size(sequence, block_size):
     if len(sequence) > block_size:
         return sequence[:block_size]
     else:
-        return sequence.extend([-1] * [block_size - len(sequence)])
+        return sequence.extend([-1] * (block_size - len(sequence)))
 
 
-def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
-    dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
+def load_and_cache_examples(args, tokenizer):
+    dataset = TextDataset(tokenizer, data_dir=args.data_dir)
     return dataset
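
One caveat on the _fit_to_block_size change: the fix replaces the erroneous list multiplier [block_size - len(sequence)] with (block_size - len(sequence)), but list.extend returns None, so the padding branch still returns None rather than the padded list. A sketch of a variant that returns the padded sequence (a standalone helper for illustration, not the committed code):

def fit_to_block_size(sequence, block_size, pad_token=-1):
    """Truncate to block_size, or pad with pad_token up to block_size."""
    if len(sequence) > block_size:
        return sequence[:block_size]
    sequence.extend([pad_token] * (block_size - len(sequence)))  # extend() returns None
    return sequence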
@@ -293,29 +289,17 @@ def main():
         "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
     )
     parser.add_argument(
-        "--decoder_name_or_path",
+        "--model_name_or_path",
         default="bert-base-cased",
         type=str,
-        help="The model checkpoint to initialize the decoder's weights with.",
+        help="The model checkpoint to initialize the encoder and decoder's weights with.",
     )
     parser.add_argument(
-        "--decoder_type",
+        "--model_type",
         default="bert",
         type=str,
         help="The decoder architecture to be fine-tuned.",
     )
-    parser.add_argument(
-        "--encoder_name_or_path",
-        default="bert-base-cased",
-        type=str,
-        help="The model checkpoint to initialize the encoder's weights with.",
-    )
-    parser.add_argument(
-        "--encoder_type",
-        default="bert",
-        type=str,
-        help="The encoder architecture to be fine-tuned.",
-    )
     parser.add_argument(
         "--learning_rate",
         default=5e-5,
@@ -346,7 +330,7 @@ def main():
     )
     args = parser.parse_args()
 
-    if args.encoder_type != "bert" or args.decoder_type != "bert":
+    if args.model_type != "bert":
         raise ValueError(
             "Only the BERT architecture is currently supported for seq2seq."
         )
@@ -358,11 +342,8 @@ def main():
     set_seed(args)
 
     # Load pretrained model and tokenizer
-    encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
-    decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
-    model = Model2Model.from_pretrained(
-        args.encoder_name_or_path, args.decoder_name_or_path
-    )
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+    model = Model2Model.from_pretrained(args.model_name_or_path)
     # model.to(device)
 
     logger.info("Training/evaluation parameters %s", args)