Commit e4e0ee14 authored by Rémi Louf
Browse files

add separator between data import and train

parent a424892f
......@@ -52,6 +52,10 @@ def set_seed(args):
torch.manual_seed(args.seed)
# ------------
# Load dataset
# ------------
class TextDataset(Dataset):
""" Abstracts the dataset used to train seq2seq models.
......@@ -212,6 +216,11 @@ def load_and_cache_examples(args, tokenizer):
return dataset
# ------------
# Train
# ------------
def train(args, train_dataset, model, tokenizer):
    """Fine-tune the pretrained model on the corpus.

    This is a placeholder: no training logic is implemented in the code
    visible here, and none of the arguments are read before the function
    raises.

    Args:
        args: Run configuration (unused by this stub).
        train_dataset: Dataset to train on (unused by this stub).
        model: Model to fine-tune (unused by this stub).
        tokenizer: Tokenizer associated with the model (unused by this stub).

    Raises:
        NotImplementedError: Always, until the training loop is written.
    """
    raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment