train_ner.py 1018 Bytes
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

from src.config import get_params
from src.utils import init_experiment
from src.dataloader import get_dataloader
from src.model import EntityTagger
from src.trainer import NERTrainer

import torch
import numpy as np
from tqdm import tqdm
import random

def random_seed(seed: int) -> None:
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch's CPU
    and CUDA generators, and configures cuDNN for reproducible results.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device. PyTorch silently
    # ignores this call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Use deterministic cuDNN kernels and disable autotuning. Benchmark mode
    # may pick a different convolution algorithm from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def train_ner(params):
    """Run one NER training experiment end to end.

    Sets up the experiment via ``init_experiment``, builds the train/dev/test
    dataloaders, creates the BERT-based ``EntityTagger`` on the GPU and hands
    everything to ``NERTrainer.train``.

    Args:
        params: Experiment configuration object (as returned by
            ``get_params``). This function reads ``logger_filename``,
            ``model_name``, ``batch_size`` and ``data_folder`` from it and
            forwards the whole object to the model and the trainer.
    """
    # Bound (though unused here) so the object returned by init_experiment
    # stays referenced while training runs.
    _logger = init_experiment(params, logger_filename=params.logger_filename)

    loaders = get_dataloader(params.model_name, params.batch_size, params.data_folder)
    train_loader, dev_loader, test_loader = loaders

    tagger = EntityTagger(params)
    # NOTE(review): the model is placed on the GPU unconditionally, so a CUDA
    # device is required to run this script.
    tagger.cuda()

    NERTrainer(params, tagger).train(train_loader, dev_loader, test_loader)


if __name__ == "__main__":
    # Parse the configuration first so the seed is known before any model or
    # dataloader is built.
    cli_params = get_params()
    random_seed(cli_params.seed)
    train_ner(cli_params)