Commit 2fbfda0d authored by Myle Ott's avatar Myle Ott
Browse files

Merge internal changes

parent 93fec886
......@@ -7,7 +7,7 @@
from .dictionary import Dictionary
from .fairseq_dataset import FairseqDataset
from .indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset # noqa: F401
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
......
......@@ -268,16 +268,16 @@ class FConvEncoder(FairseqEncoder):
'encoder_padding_mask': encoder_padding_mask, # B x T
}
def reorder_encoder_out(self, encoder_out_dict, new_order):
if encoder_out_dict['encoder_out'] is not None:
encoder_out_dict['encoder_out'] = (
encoder_out_dict['encoder_out'][0].index_select(0, new_order),
encoder_out_dict['encoder_out'][1].index_select(0, new_order),
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out['encoder_out'] is not None:
encoder_out['encoder_out'] = (
encoder_out['encoder_out'][0].index_select(0, new_order),
encoder_out['encoder_out'][1].index_select(0, new_order),
)
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = \
encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
return encoder_out_dict
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(0, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
......
......@@ -226,18 +226,18 @@ class FConvEncoder(FairseqEncoder):
'encoder_out': (x, y),
}
def reorder_encoder_out(self, encoder_out_dict, new_order):
encoder_out_dict['encoder_out'] = tuple(
eo.index_select(0, new_order) for eo in encoder_out_dict['encoder_out']
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out['encoder_out'] = tuple(
eo.index_select(0, new_order) for eo in encoder_out['encoder_out']
)
if 'pretrained' in encoder_out_dict:
encoder_out_dict['pretrained']['encoder_out'] = tuple(
if 'pretrained' in encoder_out:
encoder_out['pretrained']['encoder_out'] = tuple(
eo.index_select(0, new_order)
for eo in encoder_out_dict['pretrained']['encoder_out']
for eo in encoder_out['pretrained']['encoder_out']
)
return encoder_out_dict
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
......
......@@ -237,15 +237,15 @@ class LSTMEncoder(FairseqEncoder):
'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
}
def reorder_encoder_out(self, encoder_out_dict, new_order):
encoder_out_dict['encoder_out'] = tuple(
def reorder_encoder_out(self, encoder_out, new_order):
encoder_out['encoder_out'] = tuple(
eo.index_select(1, new_order)
for eo in encoder_out_dict['encoder_out']
for eo in encoder_out['encoder_out']
)
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = \
encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
return encoder_out_dict
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(1, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
......
......@@ -225,14 +225,14 @@ class TransformerEncoder(FairseqEncoder):
'encoder_padding_mask': encoder_padding_mask, # B x T
}
def reorder_encoder_out(self, encoder_out_dict, new_order):
if encoder_out_dict['encoder_out'] is not None:
encoder_out_dict['encoder_out'] = \
encoder_out_dict['encoder_out'].index_select(1, new_order)
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = \
encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
return encoder_out_dict
def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out['encoder_out'] is not None:
encoder_out['encoder_out'] = \
encoder_out['encoder_out'].index_select(1, new_order)
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(0, new_order)
return encoder_out
def max_positions(self):
"""Maximum input length supported by the encoder."""
......
......@@ -16,7 +16,7 @@ class FixedSchedule(FairseqLRScheduler):
super().__init__(args, optimizer)
# set defaults
args.warmup_updates = getattr(args, 'warmup_updates', 0)
args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0
self.lr = args.lr[0]
if args.warmup_updates > 0:
......
......@@ -62,7 +62,7 @@ def eval_bool(x, default=False):
return default
def parse_args_and_arch(parser, input_args=None):
def parse_args_and_arch(parser, input_args=None, parse_known=False):
# The parser doesn't know about model/criterion/optimizer-specific args, so
# we parse twice. First we parse the model/criterion/optimizer, then we
# parse a second time after adding the *-specific arguments.
......@@ -90,7 +90,11 @@ def parse_args_and_arch(parser, input_args=None):
TASK_REGISTRY[args.task].add_args(parser)
# Parse a second time.
args = parser.parse_args(input_args)
if parse_known:
args, extra = parser.parse_known_args(input_args)
else:
args = parser.parse_args(input_args)
extra = None
# Post-process args.
if hasattr(args, 'lr'):
......@@ -104,7 +108,10 @@ def parse_args_and_arch(parser, input_args=None):
if hasattr(args, 'arch'):
ARCH_CONFIG_REGISTRY[args.arch](args)
return args
if parse_known:
return args, extra
else:
return args
def get_parser(desc, default_task='translation'):
......
......@@ -5,9 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq import criterions, models
from fairseq.data import FairseqDataset
class FairseqTask(object):
"""
......@@ -33,6 +30,7 @@ class FairseqTask(object):
def dataset(self, split):
"""Return a dataset split."""
from fairseq.data import FairseqDataset
if split not in self.datasets:
raise KeyError('Dataset not loaded: ' + split)
if not isinstance(self.datasets[split], FairseqDataset):
......@@ -40,9 +38,11 @@ class FairseqTask(object):
return self.datasets[split]
def build_model(self, args):
from fairseq import models
return models.build_model(args, self)
def build_criterion(self, args):
from fairseq import criterions
return criterions.build_criterion(args, self)
def get_loss(self, model, criterion, sample):
......
......@@ -140,6 +140,11 @@ class Trainer(object):
ooms_fwd = sum(ooms_fwd)
ooms_bwd = sum(ooms_bwd)
if ooms_fwd == self.args.distributed_world_size:
print('| WARNING: OOM in all workers, skipping batch')
self.zero_grad()
return None
# aggregate stats and logging outputs
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
......@@ -178,11 +183,6 @@ class Trainer(object):
return None # buffering updates
def _forward(self, sample, eval=False):
# prepare model and optimizer
if eval:
self.model.eval()
else:
self.model.train()
loss = None
sample_size = 0
logging_output = {
......@@ -190,19 +190,26 @@ class Trainer(object):
'nsentences': sample['target'].size(0) if sample is not None else 0,
}
oom = 0
if sample is not None:
try:
try:
# prepare model and optimizer
if eval:
self.model.eval()
else:
self.model.train()
self.optimizer.zero_grad()
if sample is not None:
with torch.no_grad() if eval else contextlib.ExitStack():
# calculate loss and sample size
loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
logging_output.update(logging_output_)
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
loss = None
else:
raise e
except RuntimeError as e:
if not eval and 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
oom = 1
loss = None
else:
raise e
return loss, sample_size, logging_output, oom
def _backward(self, loss):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment