train_cli.py 748 Bytes
Newer Older
Jinjing Zhou's avatar
Jinjing Zhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from ..utils.factory import ModelFactory, PipelineFactory
from ..utils.enter_config import UserConfig
import typer
from enum import Enum
import typing
import yaml
from pathlib import Path

import isort
import autopep8

def train(
    cfg: str = typer.Option("cfg.yml", help="config yaml file name"),
):
    """Generate a training script from a YAML config and execute it.

    The config file is parsed with ``yaml.safe_load``. Its ``pipeline_name``
    entry selects an entry of ``PipelineFactory.registry``, whose
    ``gen_script`` method receives the whole config and returns Python
    source. That source is reformatted with ``autopep8`` (aggressive level 1)
    and ``isort``, then executed with ``exec`` in a fresh globals dict whose
    ``__name__`` is ``"__main__"``. Code guarded by
    ``if __name__ == "__main__":`` in the generated script therefore runs.

    Args:
        cfg: Path to the YAML config file (defaults to ``cfg.yml``).

    Raises:
        KeyError: if the parsed config mapping has no ``pipeline_name`` key.
            Presumably also raised for an unregistered pipeline name, if
            ``PipelineFactory.registry`` is a dict (verify in ``utils.factory``).
    """
    # Read the config inside a ``with`` block so the file handle is closed
    # deterministically instead of leaking until garbage collection.
    # NOTE(review): an empty YAML file makes safe_load return None, which
    # then fails below with a TypeError rather than a clear message.
    with Path(cfg).open("r") as cfg_file:
        user_cfg = yaml.safe_load(cfg_file)

    pipeline_name = user_cfg["pipeline_name"]
    output_file_content = PipelineFactory.registry[pipeline_name].gen_script(user_cfg)

    # Tidy the generated source: PEP 8 fixes first, then import ordering.
    f_code = autopep8.fix_code(output_file_content, options={'aggressive': 1})
    f_code = isort.code(f_code)

    # NOTE(review): this executes code generated from the config file, so the
    # config is effectively trusted input; only run it on configs you control.
    exec(f_code, {'__name__': '__main__'})


if __name__ == "__main__":
    # Build a Typer application that exposes `train` as its only command,
    # then hand control to it so command-line options are parsed and applied.
    cli = typer.Typer()
    register_command = cli.command()
    register_command(train)
    cli()