#!/usr/bin/env python3
"""Load a pretrained BERT sequence-classification model and tokenizer for
evaluation (BERTology example script).

NOTE(review): this header previously contained web-scrape residue (blame-view
line numbers, author avatars) that made the file syntactically invalid; it has
been removed.
"""

import argparse
import logging

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import trange

from pytorch_pretrained_bert import BertForSequenceClassification, BertTokenizer

# Configure the root logger once at import time; per-rank verbosity is
# re-adjusted inside run_model() after --local_rank is known.
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)
def run_model():
    """Parse command-line options, seed all RNGs, select the compute device
    (single-GPU, CPU, or one process of a distributed NCCL job), and load a
    pretrained BERT sequence-classification model in eval mode.

    Side effects: reads sys.argv, may initialize torch.distributed, and
    reconfigures logging verbosity per process rank.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str, default='bert-base-uncased',
                        help='pretrained model name or path to local checkpoint')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    # FIX: --fp16 was logged below but never defined, so reaching the
    # logger.info() call raised AttributeError. Default False keeps the
    # previous (non-fp16) behavior for all existing invocations.
    parser.add_argument("--fp16", action='store_true',
                        help="Whether to use 16-bit (mixed) precision")
    args = parser.parse_args()

    # Seed every RNG consumer for reproducibility; the CUDA call is a
    # harmless no-op on CPU-only machines.
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.local_rank == -1 or args.no_cuda:
        # Single-process mode: use one CUDA device if available, else CPU.
        args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        # Distributed mode: each process owns exactly one GPU.
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    # Only rank 0 (or a non-distributed run) logs at INFO; other ranks are
    # quieted to WARN to avoid duplicated output.
    logging.basicConfig(level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        args.device, n_gpu, bool(args.local_rank != -1), args.fp16))

    # Load tokenizer and model weights, move to the chosen device, and switch
    # to inference mode (disables dropout).
    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    model = BertForSequenceClassification.from_pretrained(args.model_name_or_path)
    model.to(args.device)
    model.eval()

# Script entry point: run only when executed directly, not when imported.
if __name__ == '__main__':
    run_model()