train.py 923 Bytes
Newer Older
Myle Ott's avatar
Myle Ott committed
1
#!/usr/bin/env python3 -u
Sergey Edunov's avatar
Sergey Edunov committed
2
3
4
5
6
7
8
9
10
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

Myle Ott's avatar
Myle Ott committed
11
from fairseq import options
Sergey Edunov's avatar
Sergey Edunov committed
12

Myle Ott's avatar
Myle Ott committed
13
14
15
from distributed_train import main as distributed_main
from multiprocessing_train import main as multiprocessing_main
from singleprocess_train import main as singleprocess_main
Sergey Edunov's avatar
Sergey Edunov committed
16

Myle Ott's avatar
Myle Ott committed
17
18
19
20
21
22
def main(args):
    """Dispatch training to the distributed, multi-GPU, or single-process entry point.

    The choice is made from ``args`` and the visible CUDA device count:

    * ``distributed_main`` when ``args.distributed_port`` is positive or
      ``args.distributed_init_method`` is set;
    * otherwise ``multiprocessing_main`` when more than one CUDA device is
      visible to ``torch``;
    * otherwise ``singleprocess_main``.

    Args:
        args: parsed training options. Presumably the namespace returned by
            ``options.parse_args_and_arch`` (see the ``__main__`` guard);
            it must expose ``distributed_port`` and ``distributed_init_method``.
    """
    # An explicit port or init method means the caller asked for a
    # distributed launch, so it takes priority over local device detection.
    if (args.distributed_port > 0
            or args.distributed_init_method is not None):
        distributed_main(args)
    elif torch.cuda.device_count() > 1:
        # Several local GPUs: spawn one worker process per device.
        multiprocessing_main(args)
    else:
        singleprocess_main(args)
Sergey Edunov's avatar
Sergey Edunov committed
25
26

if __name__ == '__main__':
    # Build the training argument parser, parse CLI args (including the
    # architecture-specific options), then hand off to the dispatcher.
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)