train.py 2.24 KB
Newer Older
mayp777's avatar
UPDATE  
mayp777 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import pathlib
from argparse import ArgumentParser

import lightning.pytorch as pl
from datamodule import L3DAS22DataModule
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from model import DNNBeamformer, DNNBeamformerLightningModule


def run_train(args):
    """Train a DNN beamformer on the L3DAS22 dataset with PyTorch Lightning.

    Args:
        args: Parsed command-line namespace. Reads ``exp_dir``,
            ``checkpoint_path``, ``epochs``, ``gpus``, ``dataset_path`` and
            ``batch_size``.
    """
    # Seed torch / numpy / Python RNGs so repeated runs start identically.
    pl.seed_everything(1)
    logger = TensorBoardLogger(args.exp_dir)
    callbacks = [
        # Keep the 5 checkpoints with the lowest validation loss plus the
        # latest one. ``checkpoint_path`` is the *directory* checkpoints are
        # written to; when it is None, Lightning picks a directory under the
        # logger's output.
        # NOTE(review): assumes DNNBeamformerLightningModule logs "val/loss";
        # confirm in model.py.
        ModelCheckpoint(
            dirpath=args.checkpoint_path,
            monitor="val/loss",
            save_top_k=5,
            mode="min",
            save_last=True,
        ),
    ]

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        callbacks=callbacks,
        # Hard-coded to GPU; ``devices`` is the number of GPUs to use.
        accelerator="gpu",
        devices=args.gpus,
        accumulate_grad_batches=1,
        logger=logger,
        gradient_clip_val=5,
        check_val_every_n_epoch=1,
        # Skip Lightning's pre-training validation sanity check.
        num_sanity_val_steps=0,
        # Log training metrics on every optimisation step.
        log_every_n_steps=1,
    )
    model = DNNBeamformer()
    model_module = DNNBeamformerLightningModule(model)
    data_module = L3DAS22DataModule(dataset_path=args.dataset_path, batch_size=args.batch_size)

    trainer.fit(model_module, datamodule=data_module)


def _build_parser():
    """Build the argument parser for the DNN beamformer training script."""
    parser = ArgumentParser(description="Train a DNN beamformer on the L3DAS22 dataset.")
    parser.add_argument(
        "--checkpoint-path",
        default=None,
        type=pathlib.Path,
        # Used as ModelCheckpoint's ``dirpath``: a directory to save into.
        help="Directory to save model checkpoints to. (Default: chosen by Lightning under --exp-dir)",
    )
    parser.add_argument(
        "--exp-dir",
        default=pathlib.Path("./exp/"),
        type=pathlib.Path,
        help="Directory to save checkpoints and logs to. (Default: './exp/')",
    )
    parser.add_argument(
        "--dataset-path",
        type=pathlib.Path,
        help="Path to L3DAS22 datasets.",
        required=True,
    )
    parser.add_argument(
        # Hyphenated form matches the other flags; the underscore form is
        # kept so existing command lines keep working.
        "--batch-size",
        "--batch_size",
        dest="batch_size",
        default=4,
        type=int,
        help="Batch size for training. (Default: 4)",
    )
    parser.add_argument(
        "--gpus",
        default=1,
        type=int,
        help="Number of GPUs per node to use for training. (Default: 1)",
    )
    parser.add_argument(
        "--epochs",
        default=100,
        type=int,
        help="Number of epochs to train for. (Default: 100)",
    )
    return parser


def cli_main(argv=None):
    """Parse command-line arguments and launch training.

    Args:
        argv: Optional argument list; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.
    """
    args = _build_parser().parse_args(argv)
    run_train(args)


if __name__ == "__main__":
    cli_main()