Commit 8cd56e30 authored by thomwolf

fix data processing in script

parent 578d23e0
@@ -58,12 +58,12 @@ class TextDataset(Dataset):
     [2] https://github.com/abisee/cnn-dailymail/
     """

-    def __init_(self, tokenizer_src, tokenizer_tgt, data_dir="", block_size=512):
+    def __init__(self, tokenizer, prefix='train', data_dir="", block_size=512):
         assert os.path.isdir(data_dir)

         # Load features that have already been computed if present
         cached_features_file = os.path.join(
-            data_dir, "cached_lm_{}_{}".format(block_size, data_dir)
+            data_dir, "cached_lm_{}_{}".format(block_size, prefix)
         )
         if os.path.exists(cached_features_file):
             logger.info("Loading features from cached file %s", cached_features_file)
@@ -72,7 +72,7 @@ class TextDataset(Dataset):
             return

         logger.info("Creating features from dataset at %s", data_dir)
         self.examples = []
         datasets = ["cnn", "dailymail"]
         for dataset in datasets:
             path_to_stories = os.path.join(data_dir, dataset, "stories")
@@ -91,21 +91,17 @@ class TextDataset(Dataset):
                 except IndexError:  # skip ill-formed stories
                     continue

-                story = tokenizer_src.convert_tokens_to_ids(
-                    tokenizer_src.tokenize(story)
-                )
+                story = tokenizer.encode(story)
                 story_seq = _fit_to_block_size(story, block_size)

-                summary = tokenizer_tgt.convert_tokens_to_ids(
-                    tokenizer_tgt.tokenize(summary)
-                )
+                summary = tokenizer.encode(summary)
                 summary_seq = _fit_to_block_size(summary, block_size)

                 self.examples.append((story_seq, summary_seq))

         logger.info("Saving features into cache file %s", cached_features_file)
         with open(cached_features_file, "wb") as sink:
-            pickle.dump(self.examples, sink, protocole=pickle.HIGHEST_PROTOCOL)
+            pickle.dump(self.examples, sink, protocol=pickle.HIGHEST_PROTOCOL)

     def __len__(self):
         return len(self.examples)
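As a rough sanity check of the tokenization change (sketch only; whether encode adds special tokens such as [CLS]/[SEP] depends on the transformers version in use):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    text = "Some placeholder story text."

    ids_two_step = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
    ids_encode = tokenizer.encode(text)
    # Apart from any special tokens encode may add, the two id lists should match.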
@@ -169,11 +165,11 @@ def _fit_to_block_size(sequence, block_size):
     if len(sequence) > block_size:
         return sequence[:block_size]
     else:
-        return sequence.extend([-1] * [block_size - len(sequence)])
+        return sequence.extend([-1] * (block_size - len(sequence)))


-def load_and_cache_examples(args, tokenizer_src, tokenizer_tgt):
-    dataset = TextDataset(tokenizer_src, tokenizer_tgt, file_path=args.data_dir)
+def load_and_cache_examples(args, tokenizer):
+    dataset = TextDataset(tokenizer, data_dir=args.data_dir)
     return dataset
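A self-contained sketch of the truncate-or-pad behaviour this helper appears to target (illustrative, not the script's code: the sketch returns the list explicitly, since Python's list.extend returns None, and keeps the -1 padding value seen in the diff):

    def fit_to_block_size(sequence, block_size, pad_id=-1):
        """Truncate a token-id list to block_size, or pad it with pad_id."""
        if len(sequence) > block_size:
            return sequence[:block_size]
        sequence.extend([pad_id] * (block_size - len(sequence)))
        return sequence

    assert fit_to_block_size(list(range(10)), 8) == list(range(8))
    assert fit_to_block_size(list(range(4)), 6) == [0, 1, 2, 3, -1, -1]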
@@ -293,29 +289,17 @@ def main():
         "--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer."
     )
     parser.add_argument(
-        "--decoder_name_or_path",
+        "--model_name_or_path",
         default="bert-base-cased",
         type=str,
-        help="The model checkpoint to initialize the decoder's weights with.",
+        help="The model checkpoint to initialize the encoder and decoder's weights with.",
     )
     parser.add_argument(
-        "--decoder_type",
+        "--model_type",
         default="bert",
         type=str,
         help="The decoder architecture to be fine-tuned.",
     )
-    parser.add_argument(
-        "--encoder_name_or_path",
-        default="bert-base-cased",
-        type=str,
-        help="The model checkpoint to initialize the encoder's weights with.",
-    )
-    parser.add_argument(
-        "--encoder_type",
-        default="bert",
-        type=str,
-        help="The encoder architecture to be fine-tuned.",
-    )
     parser.add_argument(
         "--learning_rate",
         default=5e-5,
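A minimal sketch of the consolidated arguments in isolation (only the two flags visible in this hunk; defaults are taken from the diff):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str)
    parser.add_argument("--model_type", default="bert", type=str)

    args = parser.parse_args([])  # rely on the defaults
    assert args.model_name_or_path == "bert-base-cased" and args.model_type == "bert"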
@@ -346,7 +330,7 @@ def main():
     )
     args = parser.parse_args()

-    if args.encoder_type != "bert" or args.decoder_type != "bert":
+    if args.model_type != "bert":
         raise ValueError(
             "Only the BERT architecture is currently supported for seq2seq."
         )
@@ -358,11 +342,8 @@ def main():
     set_seed(args)

     # Load pretrained model and tokenizer
-    encoder_tokenizer_class = AutoTokenizer.from_pretrained(args.encoder_name_or_path)
-    decoder_tokenizer_class = AutoTokenizer.from_pretrained(args.decoder_name_or_path)
-    model = Model2Model.from_pretrained(
-        args.encoder_name_or_path, args.decoder_name_or_path
-    )
+    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
+    model = Model2Model.from_pretrained(args.model_name_or_path)
     # model.to(device)

     logger.info("Training/evaluation parameters %s", args)