Commit 2fbfda0d authored by Myle Ott

Merge internal changes

parent 93fec886
@@ -7,7 +7,7 @@
 from .dictionary import Dictionary
 from .fairseq_dataset import FairseqDataset
-from .indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
+from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset  # noqa: F401
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .token_block_dataset import TokenBlockDataset
...
@@ -268,16 +268,16 @@ class FConvEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask,  # B x T
         }

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_out'] is not None:
-            encoder_out_dict['encoder_out'] = (
-                encoder_out_dict['encoder_out'][0].index_select(0, new_order),
-                encoder_out_dict['encoder_out'][1].index_select(0, new_order),
-            )
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
+    def reorder_encoder_out(self, encoder_out, new_order):
+        if encoder_out['encoder_out'] is not None:
+            encoder_out['encoder_out'] = (
+                encoder_out['encoder_out'][0].index_select(0, new_order),
+                encoder_out['encoder_out'][1].index_select(0, new_order),
+            )
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = \
+                encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out

     def max_positions(self):
         """Maximum input length supported by the encoder."""
...
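Note: reorder_encoder_out is used during beam search to realign cached encoder state with a reordered (and possibly expanded) batch of hypotheses. Below is a minimal standalone sketch of the underlying torch.index_select call, not fairseq code; this FConvEncoder reorders along dim 0 because its outputs are batch-first, while the LSTM and Transformer encoders further down use dim 1 for their time-first tensors.

    import torch

    # Three batch-first encoder states of hidden size 4.
    encoder_states = torch.arange(12, dtype=torch.float32).view(3, 4)

    # new_order maps each output slot to an original batch index; entries may
    # repeat when a source sentence is expanded into several beam hypotheses.
    new_order = torch.tensor([2, 0, 0])

    # index_select along the batch dimension reorders (and duplicates) rows.
    reordered = encoder_states.index_select(0, new_order)  # shape: (3, 4)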
@@ -226,18 +226,18 @@ class FConvEncoder(FairseqEncoder):
             'encoder_out': (x, y),
         }

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder_out'] = tuple(
-            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder_out']
-        )
-        if 'pretrained' in encoder_out_dict:
-            encoder_out_dict['pretrained']['encoder_out'] = tuple(
-                eo.index_select(0, new_order)
-                for eo in encoder_out_dict['pretrained']['encoder_out']
-            )
-        return encoder_out_dict
+    def reorder_encoder_out(self, encoder_out, new_order):
+        encoder_out['encoder_out'] = tuple(
+            eo.index_select(0, new_order) for eo in encoder_out['encoder_out']
+        )
+        if 'pretrained' in encoder_out:
+            encoder_out['pretrained']['encoder_out'] = tuple(
+                eo.index_select(0, new_order)
+                for eo in encoder_out['pretrained']['encoder_out']
+            )
+        return encoder_out

     def max_positions(self):
         """Maximum input length supported by the encoder."""
...
@@ -237,15 +237,15 @@ class LSTMEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
         }

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder_out'] = tuple(
-            eo.index_select(1, new_order)
-            for eo in encoder_out_dict['encoder_out']
-        )
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
-        return encoder_out_dict
+    def reorder_encoder_out(self, encoder_out, new_order):
+        encoder_out['encoder_out'] = tuple(
+            eo.index_select(1, new_order)
+            for eo in encoder_out['encoder_out']
+        )
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = \
+                encoder_out['encoder_padding_mask'].index_select(1, new_order)
+        return encoder_out

     def max_positions(self):
         """Maximum input length supported by the encoder."""
...
@@ -225,14 +225,14 @@ class TransformerEncoder(FairseqEncoder):
             'encoder_padding_mask': encoder_padding_mask,  # B x T
         }

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_out'] is not None:
-            encoder_out_dict['encoder_out'] = \
-                encoder_out_dict['encoder_out'].index_select(1, new_order)
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
+    def reorder_encoder_out(self, encoder_out, new_order):
+        if encoder_out['encoder_out'] is not None:
+            encoder_out['encoder_out'] = \
+                encoder_out['encoder_out'].index_select(1, new_order)
+        if encoder_out['encoder_padding_mask'] is not None:
+            encoder_out['encoder_padding_mask'] = \
+                encoder_out['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out

     def max_positions(self):
         """Maximum input length supported by the encoder."""
...
@@ -16,7 +16,7 @@ class FixedSchedule(FairseqLRScheduler):
         super().__init__(args, optimizer)

         # set defaults
-        args.warmup_updates = getattr(args, 'warmup_updates', 0)
+        args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0

         self.lr = args.lr[0]
         if args.warmup_updates > 0:
...
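The added `or 0` guards the case where `args.warmup_updates` exists but is set to None (for example when the option is registered with a None default): getattr only falls back to its default when the attribute is missing, not when it is None. A small illustration, independent of fairseq:

    from argparse import Namespace

    args = Namespace(warmup_updates=None)

    # The attribute exists, so getattr returns None and the default is ignored.
    getattr(args, 'warmup_updates', 0)        # -> None (would break the `> 0` check)

    # The `or 0` fallback coerces None to 0, keeping the comparison safe.
    getattr(args, 'warmup_updates', 0) or 0   # -> 0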
@@ -62,7 +62,7 @@ def eval_bool(x, default=False):
         return default


-def parse_args_and_arch(parser, input_args=None):
+def parse_args_and_arch(parser, input_args=None, parse_known=False):
     # The parser doesn't know about model/criterion/optimizer-specific args, so
     # we parse twice. First we parse the model/criterion/optimizer, then we
     # parse a second time after adding the *-specific arguments.
@@ -90,7 +90,11 @@ def parse_args_and_arch(parser, input_args=None):
         TASK_REGISTRY[args.task].add_args(parser)

     # Parse a second time.
-    args = parser.parse_args(input_args)
+    if parse_known:
+        args, extra = parser.parse_known_args(input_args)
+    else:
+        args = parser.parse_args(input_args)
+        extra = None

     # Post-process args.
     if hasattr(args, 'lr'):
@@ -104,6 +108,9 @@ def parse_args_and_arch(parser, input_args=None):
     if hasattr(args, 'arch'):
         ARCH_CONFIG_REGISTRY[args.arch](args)

-    return args
+    if parse_known:
+        return args, extra
+    else:
+        return args
...
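With the new `parse_known` flag, callers can collect unrecognized command-line arguments instead of having argparse exit with an error; with the default `parse_known=False` the behaviour is unchanged. A hedged usage sketch (the parser construction below is an assumption about the surrounding options module; only parse_args_and_arch and parse_known come from this commit):

    from fairseq import options

    # Assumed helper for building a fairseq argument parser; substitute whatever
    # parser your entry point already constructs.
    parser = options.get_training_parser()

    # Returns (args, extra), where extra is the list of arguments argparse did
    # not recognize, e.g. flags meant for a downstream tool.
    args, extra = options.parse_args_and_arch(parser, parse_known=True)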
@@ -5,9 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-from fairseq import criterions, models
-from fairseq.data import FairseqDataset
-

 class FairseqTask(object):
     """
@@ -33,6 +30,7 @@ class FairseqTask(object):
     def dataset(self, split):
         """Return a dataset split."""
+        from fairseq.data import FairseqDataset
         if split not in self.datasets:
             raise KeyError('Dataset not loaded: ' + split)
         if not isinstance(self.datasets[split], FairseqDataset):
@@ -40,9 +38,11 @@ class FairseqTask(object):
         return self.datasets[split]

     def build_model(self, args):
+        from fairseq import models
         return models.build_model(args, self)

     def build_criterion(self, args):
+        from fairseq import criterions
         return criterions.build_criterion(args, self)

     def get_loss(self, model, criterion, sample):
...
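Moving the criterions, models, and FairseqDataset imports from module level into the methods that need them defers the import to call time, the usual way to sidestep circular-import problems when those modules in turn import from the tasks package (that circular dependency is an assumption here, but it is the typical motivation for this pattern). A generic sketch of the pattern, not fairseq code:

    # module_a.py
    def build_widget(config):
        # Deferred import: module_b imports module_a at its top level, so
        # importing it here, at call time, avoids a circular import while
        # both modules are still loading.
        from module_b import Widget
        return Widget(config)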
@@ -140,6 +140,11 @@ class Trainer(object):
         ooms_fwd = sum(ooms_fwd)
         ooms_bwd = sum(ooms_bwd)

+        if ooms_fwd == self.args.distributed_world_size:
+            print('| WARNING: OOM in all workers, skipping batch')
+            self.zero_grad()
+            return None
+
         # aggregate stats and logging outputs
         ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
         nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
@@ -178,11 +183,6 @@ class Trainer(object):
             return None  # buffering updates

     def _forward(self, sample, eval=False):
-        # prepare model and optimizer
-        if eval:
-            self.model.eval()
-        else:
-            self.model.train()
         loss = None
         sample_size = 0
         logging_output = {
@@ -190,8 +190,15 @@ class Trainer(object):
             'nsentences': sample['target'].size(0) if sample is not None else 0,
         }
         oom = 0
-        if sample is not None:
-            try:
+        try:
+            # prepare model and optimizer
+            if eval:
+                self.model.eval()
+            else:
+                self.model.train()
+            self.optimizer.zero_grad()
+            if sample is not None:
                 with torch.no_grad() if eval else contextlib.ExitStack():
                     # calculate loss and sample size
                     loss, sample_size, logging_output_ = self.task.get_loss(self.model, self.criterion, sample)
...
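The new guard in the Trainer skips a batch only when every worker reported an out-of-memory error during the forward pass; with a partial OOM the remaining workers can still contribute, so aggregation proceeds as before. A simplified sketch of the same check (the function and argument names here are illustrative, not the Trainer API):

    def maybe_skip_batch(ooms_fwd, world_size, zero_grad):
        # ooms_fwd: number of workers that hit an OOM during the forward pass.
        if ooms_fwd == world_size:
            # Nothing useful was computed anywhere: clear any partial gradients
            # and tell the caller to skip this batch entirely.
            print('| WARNING: OOM in all workers, skipping batch')
            zero_grad()
            return True
        return False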