Commit 7633129b authored by Myle Ott, committed by Facebook GitHub Bot

Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
parent 0cb87130
......@@ -22,6 +22,7 @@ class SequenceGenerator(object):
match_source_len=False, no_repeat_ngram_size=0
):
"""Generates translations of a given source sentence.
Args:
beam_size (int, optional): beam width (default: 1)
min/maxlen (int, optional): the length of the generated output will
......@@ -90,11 +91,14 @@ class SequenceGenerator(object):
cuda=False, timer=None, prefix_size=0,
):
"""Iterate over a batched dataset and yield individual translations.
Args:
maxlen_a/b: generate sequences of maximum length ax + b,
where x is the source sentence length.
cuda: use GPU for generation
timer: StopwatchMeter for timing generations.
maxlen_a/b (int, optional): generate sequences of maximum length
``ax + b``, where ``x`` is the source sentence length.
cuda (bool, optional): use GPU for generation
timer (StopwatchMeter, optional): time generations
prefix_size (int, optional): prefill the generation with the gold
prefix up to this length.
"""
if maxlen_b is None:
maxlen_b = self.maxlen
......@@ -132,12 +136,13 @@ class SequenceGenerator(object):
"""Generate a batch of translations.
Args:
encoder_input: dictionary containing the inputs to
model.encoder.forward
beam_size: int overriding the beam size. defaults to
self.beam_size
max_len: maximum length of the generated sequence
prefix_tokens: force decoder to begin with these tokens
encoder_input (dict): dictionary containing the inputs to
*model.encoder.forward*.
beam_size (int, optional): overriding the beam size
(default: *self.beam_size*).
max_len (int, optional): maximum length of the generated sequence
prefix_tokens (LongTensor, optional): force decoder to begin with
these tokens
"""
with torch.no_grad():
return self._generate(encoder_input, beam_size, maxlen, prefix_tokens)
......
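Note on the prefix_size / prefix_tokens arguments documented above: the following is a hedged sketch of how a gold prefix is typically sliced from a batch of targets. The (batch, time) LongTensor layout of sample['target'] is an assumption here, not something this diff states.

def gold_prefix(target, prefix_size):
    # Hypothetical helper: take the first `prefix_size` gold target tokens per
    # sentence so the decoder is forced to begin generation with them.
    if prefix_size <= 0:
        return None  # no forced prefix
    return target[:, :prefix_size]  # LongTensor of shape (batch, prefix_size)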
......@@ -87,4 +87,3 @@ class SequenceScorer(object):
index=sample['target'].data.unsqueeze(-1),
)
return avg_probs.squeeze(2), avg_attn
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
......
......@@ -61,29 +61,32 @@ class FairseqTask(object):
def get_batch_iterator(
self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
ignore_invalid_inputs=False, required_batch_size_multiple=1,
seed=1, num_shards=1, shard_id=0,
seed=1, num_shards=1, shard_id=0, num_workers=0,
):
"""
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch.
Default: ``None``
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch. Default: ``None``
batch (default: None).
max_positions (optional): max sentence length supported by the
model. Default: ``None``
model (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long. Default: ``False``
sentences that are too long (default: False).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N. Default: ``1``
be a multiple of N (default: 1).
seed (int, optional): seed for random number generator for
reproducibility. Default: ``1``
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards. Default: ``1``
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return. Default: ``0``
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
......@@ -114,6 +117,7 @@ class FairseqTask(object):
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
)
def build_model(self, args):
......
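A hedged usage sketch for the updated get_batch_iterator() signature, including the new num_workers argument; the names task, dataset, and max_positions are assumed to come from the surrounding training setup, mirroring how train.py and generate.py call it later in this diff.

itr = task.get_batch_iterator(
    dataset=dataset,
    max_tokens=4000,                  # cap on tokens per batch
    max_sentences=None,               # no per-batch sentence cap
    max_positions=max_positions,      # max length supported by the model
    ignore_invalid_inputs=True,       # skip over-long sentences instead of raising
    required_batch_size_multiple=8,
    seed=1,
    num_shards=1,
    shard_id=0,
    num_workers=4,                    # new: load data in 4 background subprocesses
).next_epoch_itr(shuffle=True)
for batch in itr:
    ...  # feed `batch` to the trainer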
......@@ -10,9 +10,15 @@ import numpy as np
import os
from fairseq.data import (
ConcatDataset, Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset, TruncatedDictionary,
IndexedCachedDataset, IndexedDataset)
ConcatDataset,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
MonolingualDataset,
TokenBlockDataset,
TruncatedDictionary,
)
from . import FairseqTask, register_task
......@@ -60,6 +66,8 @@ class LanguageModelingTask(FairseqTask):
'If set to "eos", includes only one sentence per sample.')
parser.add_argument('--tokens-per-sample', default=1024, type=int,
help='max number of tokens per sample for LM dataset')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
parser.add_argument('--output-dictionary-size', default=-1, type=int,
......@@ -139,7 +147,10 @@ class LanguageModelingTask(FairseqTask):
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
elif not self.args.raw_text and IndexedDataset.exists(path):
ds = IndexedDataset(path, fix_lua_indexing=True)
if self.args.lazy_load:
ds = IndexedDataset(path, fix_lua_indexing=True)
else:
ds = IndexedCachedDataset(path, fix_lua_indexing=True)
else:
if k > 0:
break
......@@ -148,9 +159,11 @@ class LanguageModelingTask(FairseqTask):
loaded_datasets.append(
TokenBlockDataset(
ds, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
ds, ds.sizes, self.args.tokens_per_sample,
pad=self.dictionary.pad(), eos=self.dictionary.eos(),
break_mode=self.args.sample_break_mode, include_targets=True,
))
)
)
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
......
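The --lazy-load / IndexedCachedDataset selection above recurs in the multilingual and translation tasks below; here is a minimal standalone sketch of the pattern. The memory-versus-speed reading in the comments is an interpretation of the two dataset classes, not something stated in the diff.

from fairseq.data import IndexedCachedDataset, IndexedDataset

def load_binarized(path, lazy_load=False):
    """Pick a binarized dataset backend based on --lazy-load."""
    if not IndexedDataset.exists(path):
        return None
    if lazy_load:
        # Read examples from disk on demand: smaller memory footprint.
        return IndexedDataset(path, fix_lua_indexing=True)
    # Default: cache the examples in memory for faster repeated access.
    return IndexedCachedDataset(path, fix_lua_indexing=True)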
......@@ -12,8 +12,12 @@ import torch
from fairseq import options
from fairseq.data import (
Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
IndexedRawTextDataset, RoundRobinZipDatasets,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
RoundRobinZipDatasets,
)
from fairseq.models import FairseqMultiModel
......@@ -55,6 +59,8 @@ class MultilingualTranslationTask(FairseqTask):
help='source language (only needed for inference)')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language (only needed for inference)')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
......@@ -112,15 +118,18 @@ class MultilingualTranslationTask(FairseqTask):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
elif not self.args.raw_text and IndexedDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedInMemoryDataset.exists(path):
return IndexedInMemoryDataset(path, fix_lua_indexing=True)
elif IndexedDataset.exists(path):
if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None
def sort_lang_pair(lang_pair):
......
......@@ -6,13 +6,17 @@
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os
from fairseq import options, utils
from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
ConcatDataset,
data_utils,
Dictionary,
IndexedCachedDataset,
IndexedDataset,
IndexedRawTextDataset,
LanguagePairDataset,
)
from . import FairseqTask, register_task
......@@ -49,6 +53,8 @@ class TranslationTask(FairseqTask):
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
help='target language')
parser.add_argument('--lazy-load', action='store_true',
help='load the dataset lazily')
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
......@@ -132,7 +138,10 @@ class TranslationTask(FairseqTask):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
elif IndexedDataset.exists(path):
return IndexedCachedDataset(path, fix_lua_indexing=True)
if self.args.lazy_load:
return IndexedDataset(path, fix_lua_indexing=True)
else:
return IndexedCachedDataset(path, fix_lua_indexing=True)
return None
src_datasets = []
......
......@@ -6,10 +6,11 @@
# can be found in the PATENTS file in the same directory.
from collections import Counter
import os, re
from multiprocessing import Pool
import os
import re
import torch
from multiprocessing import Pool
SPACE_NORMALIZER = re.compile(r"\s+")
......@@ -27,7 +28,8 @@ def safe_readline(f):
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins
f.seek(pos) # search where this character begins
class Tokenizer:
......@@ -41,7 +43,7 @@ class Tokenizer:
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
......@@ -73,14 +75,17 @@ class Tokenizer:
merge_result(Tokenizer.add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))
@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line,
append_eos=True, reverse_order=False,
offset=0, end=-1):
def binarize(
filename, dict, consumer, tokenize=tokenize_line, append_eos=True,
reverse_order=False, offset=0, end=-1,
):
nseq, ntok = 0, 0
replaced = Counter()
def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])
with open(filename, 'r') as f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
......
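The offset/end arguments threaded through binarize() above let several workers process disjoint byte ranges of the same file; below is a hedged sketch of that chunked-read pattern, assuming effectively single-byte text (fairseq's safe_readline, shown earlier in this hunk, is what handles seeks that land inside a multi-byte character).

def read_chunk(filename, offset, chunk_size):
    # Sketch: yield the lines "owned" by the worker covering bytes
    # [offset, offset + chunk_size) of the file.
    end = offset + chunk_size
    with open(filename, 'r') as f:
        f.seek(offset)
        if offset > 0:
            f.readline()  # drop the first, possibly partial, line; the worker
                          # for the previous chunk reads that line to completion
        line = f.readline()  # note: next(f) would break f.tell(), hence readline()
        while line:
            yield line
            if f.tell() > end:
                break  # the line straddling `end` belongs to this worker
            line = f.readline()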
......@@ -30,22 +30,23 @@ class Trainer(object):
"""
def __init__(self, args, task, model, criterion, dummy_batch, oom_batch=None):
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
self.args = args
self.task = task
# copy model and criterion to current device
self.criterion = criterion.cuda()
self.criterion = criterion
self._model = model
self.cuda = torch.cuda.is_available() and not args.cpu
if args.fp16:
self._model = model.half().cuda()
else:
self._model = model.cuda()
self._model = self._model.half()
if self.cuda:
self.criterion = self.criterion.cuda()
self._model = self._model.cuda()
self._dummy_batch = dummy_batch
self._oom_batch = oom_batch
self._lr_scheduler = None
self._num_updates = 0
self._optim_history = None
self._optimizer = None
......@@ -71,7 +72,6 @@ class Trainer(object):
self.meters['wall'] = TimeMeter() # wall time in seconds
self.meters['train_wall'] = StopwatchMeter() # train wall time in seconds
@property
def model(self):
if self._wrapped_model is None:
......@@ -89,19 +89,26 @@ class Trainer(object):
self._build_optimizer()
return self._optimizer
@property
def lr_scheduler(self):
if self._lr_scheduler is None:
self._lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
return self._lr_scheduler
def _build_optimizer(self):
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
if self.args.fp16:
if torch.cuda.get_device_capability(0)[0] < 7:
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
print('| WARNING: your device does NOT support faster training with --fp16, '
'please switch to FP32 which is likely to be faster')
params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
if self.args.memory_efficient_fp16:
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer(self.args, params)
else:
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
else:
if torch.cuda.get_device_capability(0)[0] >= 7:
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
print('| NOTICE: your device may support faster training with --fp16')
self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)
self._optimizer = optim.build_optimizer(self.args, params)
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
......@@ -151,7 +158,8 @@ class Trainer(object):
# reproducible results when resuming from checkpoints
seed = self.args.seed + self.get_num_updates()
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if self.cuda:
torch.cuda.manual_seed(seed)
self.model.train()
self.zero_grad()
......@@ -296,7 +304,8 @@ class Trainer(object):
for p in self.model.parameters():
if p.grad is not None:
del p.grad # free some memory
torch.cuda.empty_cache()
if self.cuda:
torch.cuda.empty_cache()
return self.valid_step(sample, raise_oom=True)
else:
raise e
......@@ -377,4 +386,6 @@ class Trainer(object):
def _prepare_sample(self, sample):
if sample is None or len(sample) == 0:
return None
return utils.move_to_cuda(sample)
if self.cuda:
sample = utils.move_to_cuda(sample)
return sample
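A condensed, illustrative-only sketch of the pattern the Trainer changes above follow: device-specific calls are gated on a single cuda flag so the same code path runs on CPU when --cpu is set or no GPU is available. This is a simplification, not the actual Trainer class.

import torch
from fairseq import utils

class DeviceAwareComponent:
    def __init__(self, args, model, criterion):
        self.cuda = torch.cuda.is_available() and not getattr(args, 'cpu', False)
        if getattr(args, 'fp16', False):
            model = model.half()
        if self.cuda:
            model = model.cuda()
            criterion = criterion.cuda()
        self.model, self.criterion = model, criterion

    def prepare_sample(self, sample):
        # move inputs to the GPU only when one is actually in use
        return utils.move_to_cuda(sample) if self.cuda else sample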
......@@ -378,6 +378,14 @@ def item(tensor):
return tensor
def clip_grad_norm_(tensor, max_norm):
grad_norm = item(torch.norm(tensor))
if grad_norm > max_norm > 0:
clip_coef = max_norm / (grad_norm + 1e-6)
tensor.mul_(clip_coef)
return grad_norm
def fill_with_neg_inf(t):
"""FP16-compatible function that fills a tensor with -inf."""
return t.float().fill_(float('-inf')).type_as(t)
......
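A small usage sketch exercising the two helpers added above, assuming (as the hunk context suggests) they live in fairseq.utils:

import torch
from fairseq.utils import clip_grad_norm_, fill_with_neg_inf

flat_grads = torch.tensor([3.0, 4.0])            # L2 norm = 5.0
norm = clip_grad_norm_(flat_grads, max_norm=1.0)
# flat_grads is rescaled in place to roughly unit norm; the pre-clipping
# norm (5.0) is returned so it can still be logged.

scores = torch.zeros(2, 3, dtype=torch.half)
mask = fill_with_neg_inf(scores.clone())
# the fill happens in float32 and the result is cast back to the input dtype,
# which is what makes the helper FP16-compatible per its docstring.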
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.fb.rendezvous.zeus # noqa: F401
from fairseq import options
from train import main
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)
......@@ -11,7 +11,7 @@ Translate pre-processed data with a trained model.
import torch
from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
from fairseq import bleu, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
......@@ -41,7 +41,9 @@ def main(args):
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))
models, _model_args = utils.load_ensemble_for_inference(
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
)
# Optimize ensemble for generation
for model in models:
......@@ -69,6 +71,7 @@ def main(args):
required_batch_size_multiple=8,
num_shards=args.num_shards,
shard_id=args.shard_id,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
# Initialize generator
......
......@@ -75,8 +75,9 @@ def main(args):
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
model_paths = args.path.split(':')
models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides))
models, _model_args = utils.load_ensemble_for_inference(
args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides),
)
# Set dictionaries
tgt_dict = task.target_dictionary
......
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import signal
import torch
from fairseq import distributed_utils, options
from train import main as single_process_main
def main(args):
if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
print('| WARNING: when using --update-freq on a single machine, you '
'will get better performance with --ddp-backend=no_c10d')
mp = torch.multiprocessing.get_context('spawn')
# Create a thread to listen for errors in the child processes.
error_queue = mp.SimpleQueue()
error_handler = ErrorHandler(error_queue)
# Train with multiprocessing.
procs = []
base_rank = args.distributed_rank
for i in range(torch.cuda.device_count()):
args.distributed_rank = base_rank + i
args.device_id = i
procs.append(mp.Process(target=run, args=(args, error_queue, ), daemon=True))
procs[i].start()
error_handler.add_child(procs[i].pid)
for p in procs:
p.join()
def run(args, error_queue):
try:
args.distributed_rank = distributed_utils.distributed_init(args)
single_process_main(args)
except KeyboardInterrupt:
pass # killed by parent, do nothing
except Exception:
# propagate exception to parent process, keeping original traceback
import traceback
error_queue.put((args.distributed_rank, traceback.format_exc()))
class ErrorHandler(object):
"""A class that listens for exceptions in children processes and propagates
the tracebacks to the parent process."""
def __init__(self, error_queue):
import signal
import threading
self.error_queue = error_queue
self.children_pids = []
self.error_thread = threading.Thread(target=self.error_listener, daemon=True)
self.error_thread.start()
signal.signal(signal.SIGUSR1, self.signal_handler)
def add_child(self, pid):
self.children_pids.append(pid)
def error_listener(self):
(rank, original_trace) = self.error_queue.get()
self.error_queue.put((rank, original_trace))
os.kill(os.getpid(), signal.SIGUSR1)
def signal_handler(self, signalnum, stackframe):
for pid in self.children_pids:
os.kill(pid, signal.SIGINT) # kill children processes
(rank, original_trace) = self.error_queue.get()
msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
msg += original_trace
raise Exception(msg)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)
......@@ -18,7 +18,7 @@ import shutil
from fairseq.data import indexed_dataset, dictionary
from fairseq.tokenizer import Tokenizer, tokenize_line
from multiprocessing import Pool, Manager, Process
from multiprocessing import Pool
def get_parser():
......
......@@ -50,6 +50,14 @@ class TestTranslation(unittest.TestCase):
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--fp16'])
generate_main(data_dir)
def test_memory_efficient_fp16(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_memory_efficient_fp16') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--memory-efficient-fp16'])
generate_main(data_dir)
def test_update_freq(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_update_freq') as data_dir:
......@@ -68,8 +76,7 @@ class TestTranslation(unittest.TestCase):
data_dir, 'fconv_iwslt_de_en', ['--max-target-positions', '5'],
)
self.assertTrue(
'skip this example with --skip-invalid-size-inputs-valid-test' \
in str(context.exception)
'skip this example with --skip-invalid-size-inputs-valid-test' in str(context.exception)
)
train_translation_model(
data_dir, 'fconv_iwslt_de_en',
......
......@@ -12,10 +12,6 @@ import os
import tempfile
import unittest
import torch
from fairseq import options
from . import test_binaries
......@@ -79,6 +75,12 @@ class TestReproducibility(unittest.TestCase):
'--fp16-init-scale', '4096',
])
def test_reproducibility_memory_efficient_fp16(self):
self._test_reproducibility('test_reproducibility_memory_efficient_fp16', [
'--memory-efficient-fp16',
'--fp16-init-scale', '4096',
])
if __name__ == '__main__':
unittest.main()
......@@ -39,8 +39,10 @@ def mock_dict():
def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
tokens = torch.LongTensor(list(range(epoch_size)))
tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False)
tokens = torch.LongTensor(list(range(epoch_size))).view(1, -1)
tokens_ds = data.TokenBlockDataset(
tokens, sizes=[tokens.size(-1)], block_size=1, pad=0, eos=1, include_targets=False,
)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
epoch_itr = data.EpochBatchIterator(
......@@ -64,7 +66,6 @@ class TestLoadCheckpoint(unittest.TestCase):
self.applied_patches = [patch(p, d) for p, d in self.patches.items()]
[p.start() for p in self.applied_patches]
def test_load_partial_checkpoint(self):
with contextlib.redirect_stdout(StringIO()):
trainer, epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
......
......@@ -28,9 +28,8 @@ def main(args):
args.max_tokens = 6000
print(args)
if not torch.cuda.is_available():
raise NotImplementedError('Training on CPU is not supported')
torch.cuda.set_device(args.device_id)
if torch.cuda.is_available() and not args.cpu:
torch.cuda.set_device(args.device_id)
torch.manual_seed(args.seed)
# Setup task, e.g., translation, language modeling, etc.
......@@ -74,6 +73,7 @@ def main(args):
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
num_workers=args.num_workers,
)
# Load the latest checkpoint if one is available
......@@ -211,6 +211,7 @@ def validate(args, trainer, task, epoch_itr, subsets):
seed=args.seed,
num_shards=args.distributed_world_size,
shard_id=args.distributed_rank,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch,
......@@ -306,7 +307,15 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
# remove old checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt')
for old_chk in checkpoints[args.keep_interval_updates:]:
os.remove(old_chk)
if os.path.lexists(old_chk):
os.remove(old_chk)
if args.keep_last_epochs > 0:
# remove old epoch checkpoints; checkpoints are sorted in descending order
checkpoints = utils.checkpoint_paths(args.save_dir, pattern=r'checkpoint\d+\.pt')
for old_chk in checkpoints[args.keep_last_epochs:]:
if os.path.lexists(old_chk):
os.remove(old_chk)
def load_checkpoint(args, trainer, epoch_itr):
......@@ -346,23 +355,50 @@ def load_dataset_splits(task, splits):
raise e
def distributed_main(i, args):
import socket
args.device_id = i
if args.distributed_rank is None: # torch.multiprocessing.spawn
args.distributed_rank = i
args.distributed_rank = distributed_utils.distributed_init(args)
print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
main(args)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
if args.distributed_port > 0 or args.distributed_init_method is not None:
from distributed_train import main as distributed_main
if args.distributed_init_method is None:
distributed_utils.infer_init_method(args)
distributed_main(args)
if args.distributed_init_method is not None:
# distributed training
distributed_main(args.device_id, args)
args.distributed_rank = distributed_utils.distributed_init(args)
main(args)
elif args.distributed_world_size > 1:
from multiprocessing_train import main as multiprocessing_main
# Set distributed training parameters for a single node.
args.distributed_world_size = torch.cuda.device_count()
# fallback for single node with multiple GPUs
port = random.randint(10000, 20000)
args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
args.distributed_port = port + 1
multiprocessing_main(args)
args.distributed_rank = None # set based on device id
print(
'''| NOTE: you may get better performance with:
python -m torch.distributed.launch --nproc_per_node {ngpu} train.py {no_c10d}(...)
'''.format(
ngpu=args.distributed_world_size,
no_c10d=(
'--ddp-backend=no_c10d ' if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d'
else ''
),
)
)
torch.multiprocessing.spawn(
fn=distributed_main,
args=(args, ),
nprocs=args.distributed_world_size,
)
else:
# single GPU training
main(args)
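For reference, a minimal standalone sketch of the torch.multiprocessing.spawn fallback used above (the worker body here is a placeholder): spawn() invokes fn(i, *args) in each of nprocs child processes, passing the process index i first, which is why distributed_main(i, args) takes the index as its first argument and uses it as the device id.

import torch.multiprocessing as mp

def worker(i, args):
    # `i` is the process index supplied by spawn(); train.py above uses it as
    # the CUDA device id and, when unset, as the distributed rank.
    print('worker {} started with {}'.format(i, args))

if __name__ == '__main__':
    mp.spawn(fn=worker, args=({'seed': 1},), nprocs=2)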