Commit b3261e7a authored by Rémi Louf

read parameters from CLI, load model & tokenizer

parent d889e0b7
@@ -30,12 +30,15 @@ Gao, Ming Zhou, and Hsiao-Wuen Hon. “Unified Language Model Pre-Training for
 Natural Language Understanding and Generation.” (May 2019) ArXiv:1905.03197
 """
 
+import argparse
 import logging
 import random
 
 import numpy as np
 import torch
 
+from transformers import BertConfig, Bert2Rnd, BertTokenizer
+
 logger = logging.getLogger(__name__)
@@ -43,25 +46,60 @@ def set_seed(args):
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    if args.n_gpu > 0:
-        torch.cuda.manual_seed_all(args.seed)
 
 
+def load_and_cache_examples(args, tokenizer):
+    raise NotImplementedError
+
+
 def train(args, train_dataset, model, tokenizer):
     """ Fine-tune the pretrained model on the corpus. """
+    # Data sampler
+    # Data loader
+    # Training
     raise NotImplementedError
 
 
 def evaluate(args, model, tokenizer, prefix=""):
     raise NotImplementedError
 
 
 def main():
-    raise NotImplementedError
+    parser = argparse.ArgumentParser()
+
+    # Required parameters
+    parser.add_argument("--train_data_file",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="The input training data file (a text file).")
+    parser.add_argument("--output_dir",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="The output directory where the model predictions and checkpoints will be written.")
+
+    # Optional parameters
+    parser.add_argument("--model_name_or_path",
+                        default="bert-base-cased",
+                        type=str,
+                        help="The model checkpoint for weights initialization.")
+    parser.add_argument("--seed", default=42, type=int)
+    args = parser.parse_args()
+
+    # Set up training device
+    device = torch.device("cpu")
+
+    # Set seed
+    set_seed(args)
+
+    # Load pretrained model and tokenizer
+    config_class, model_class, tokenizer_class = BertConfig, Bert2Rnd, BertTokenizer
+    config = config_class.from_pretrained(args.model_name_or_path)
+    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
+    model = model_class.from_pretrained(args.model_name_or_path, config=config)
+    model.to(device)
+
+    logger.info("Training/evaluation parameters %s", args)
+
+    # Training
+    train_dataset = load_and_cache_examples(args, tokenizer)
+    global_step, tr_loss = train(args, train_dataset, model, tokenizer)
+    logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
 
-def __main__():
+if __name__ == "__main__":
     main()
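With the argument parser above in place, the script could be invoked roughly as follows. The file name run_seq2seq_finetuning.py is an assumption (this page does not show the path), and train() is still a stub that raises NotImplementedError, so the run stops after the model and tokenizer are loaded.

python run_seq2seq_finetuning.py \
    --train_data_file /path/to/train.txt \
    --output_dir /path/to/output \
    --model_name_or_path bert-base-cased \
    --seed 42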
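For context on the objective this script is building toward: the module docstring describes concatenating the document and its summary, masking summary tokens at random, and maximizing the likelihood of the masked tokens (the UniLM-style objective of [1]). The sketch below is not part of the commit; it is a minimal illustration of that masking step, assuming a BertTokenizer-like tokenizer, and the helper name build_masked_example is hypothetical.

# Minimal sketch of the masking scheme described in the module docstring:
# concatenate document and summary, mask summary tokens at random, and keep
# the original ids as masked-LM labels. Uses -100 as the ignore index, the
# convention expected by PyTorch's CrossEntropyLoss and the HF LM heads.
import random

import torch


def build_masked_example(tokenizer, document, summary, mask_prob=0.15):
    # Concatenate: [CLS] document [SEP] summary [SEP]
    doc_ids = tokenizer.encode(document, add_special_tokens=False)
    sum_ids = tokenizer.encode(summary, add_special_tokens=False)
    input_ids = (
        [tokenizer.cls_token_id]
        + doc_ids
        + [tokenizer.sep_token_id]
        + sum_ids
        + [tokenizer.sep_token_id]
    )

    # Only summary positions are candidates for masking.
    summary_start = len(doc_ids) + 2
    summary_end = summary_start + len(sum_ids)

    labels = [-100] * len(input_ids)  # -100 = position ignored by the loss
    for i in range(summary_start, summary_end):
        if random.random() < mask_prob:
            labels[i] = input_ids[i]              # predict the original token
            input_ids[i] = tokenizer.mask_token_id

    return torch.tensor(input_ids), torch.tensor(labels)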