#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

from fairseq import options

from distributed_train import main as distributed_main
from multiprocessing_train import main as multiprocessing_main
from singleprocess_train import main as singleprocess_main


def main(args):
    """Route training to the entry point that matches the requested setup.

    Exactly one of the imported ``*_main`` functions is invoked with ``args``:

    * ``distributed_main`` when ``args.distributed_port`` is positive or
      ``args.distributed_init_method`` is not ``None``;
    * otherwise ``multiprocessing_main`` when
      ``args.distributed_world_size`` is greater than 1;
    * otherwise ``singleprocess_main``.

    Args:
        args: Parsed training options. Only ``distributed_port``,
            ``distributed_init_method`` and ``distributed_world_size`` are
            read here; the whole object is forwarded to the chosen entry point.

    Returns:
        None. Whatever the chosen entry point returns is discarded.
    """
    # NOTE(review): presumably a positive port or an explicit init method means
    # this process joins an externally configured distributed job -- confirm
    # against distributed_train.
    wants_distributed = (
        args.distributed_port > 0
        or args.distributed_init_method is not None
    )
    if wants_distributed:
        distributed_main(args)
        return

    # More than one worker requested without an explicit rendezvous setting;
    # presumably multiprocessing_train launches the workers itself -- confirm.
    if args.distributed_world_size > 1:
        multiprocessing_main(args)
        return

    singleprocess_main(args)


if __name__ == '__main__':
    # Script entry point: build the training argument parser, let
    # fairseq.options parse the command line (parse_args_and_arch presumably
    # also resolves architecture-specific options -- confirm), then dispatch.
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)