train.py 913 Bytes
Newer Older
Myle Ott's avatar
Myle Ott committed
1
#!/usr/bin/env python3 -u
Sergey Edunov's avatar
Sergey Edunov committed
2
3
4
5
6
7
8
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

Myle Ott's avatar
Myle Ott committed
9
from fairseq import options
Sergey Edunov's avatar
Sergey Edunov committed
10

Myle Ott's avatar
Myle Ott committed
11
12
13
from distributed_train import main as distributed_main
from multiprocessing_train import main as multiprocessing_main
from singleprocess_train import main as singleprocess_main
Sergey Edunov's avatar
Sergey Edunov committed
14

Myle Ott's avatar
Myle Ott committed
15

Myle Ott's avatar
Myle Ott committed
16
17
18
19
def main(args):
    """Run training through the entry point selected by the distributed settings.

    The first matching rule wins:

    1. ``args.distributed_port > 0`` or ``args.distributed_init_method`` is
       not ``None``  ->  ``distributed_main(args)``
    2. ``args.distributed_world_size > 1``  ->  ``multiprocessing_main(args)``
    3. anything else  ->  ``singleprocess_main(args)``

    Args:
        args: parsed training options (as produced by
            ``options.parse_args_and_arch``); only the ``distributed_*``
            fields listed above are read here.

    Returns:
        None. Whatever the chosen entry point returns is discarded.
    """
    # Keep the port check first: the original short-circuits on it before
    # touching ``distributed_init_method``.
    launch_distributed = (
        args.distributed_port > 0
        or args.distributed_init_method is not None
    )
    if launch_distributed:
        distributed_main(args)
        return

    # Only consulted when the explicit distributed path was not taken.
    if args.distributed_world_size > 1:
        multiprocessing_main(args)
        return

    singleprocess_main(args)
Sergey Edunov's avatar
Sergey Edunov committed
24

Myle Ott's avatar
Myle Ott committed
25

Sergey Edunov's avatar
Sergey Edunov committed
26
if __name__ == '__main__':
    # Script entry point: build the training CLI, parse it (including any
    # architecture-specific options), then hand the result to ``main``.
    training_parser = options.get_training_parser()
    main(options.parse_args_and_arch(training_parser))