Commit ba989ed1 authored by Myle Ott, committed by Facebook Github Bot

Small features + lint

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/588

Differential Revision: D15389638

Pulled By: myleott

fbshipit-source-id: 4632ce22d51dc2c74d250bae999630095d849701
parent 3bfbb49b

@@ -103,6 +103,7 @@ def load_checkpoint(args, trainer):
             args.reset_optimizer,
             args.reset_lr_scheduler,
             eval(args.optimizer_overrides),
+            reset_meters=args.reset_meters,
         )
         if extra_state is not None and 'best' in extra_state and not args.reset_optimizer:

@@ -50,11 +50,12 @@ class ConcatDataset(FairseqDataset):

     @property
     def supports_prefetch(self):
-        return all([d.supports_prefetch for d in self.datasets])
+        return any(getattr(d, 'supports_prefetch', False) for d in self.datasets)

     def prefetch(self, indices):
         frm = 0
         for to, ds in zip(self.cumulative_sizes, self.datasets):
             real_size = len(ds)
-            ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
+            if getattr(ds, 'supports_prefetch', False):
+                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
             frm = to
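
Why this hunk switches from all(...) to any(...) and adds a per-dataset guard: a ConcatDataset that mixes a prefetch-capable child with a plain one should still prefetch where it can, rather than disabling prefetch entirely or crashing on children that lack the attribute. A minimal sketch with toy stand-ins, not fairseq code:

class CachedDataset:
    supports_prefetch = True

    def __init__(self, data):
        self.data, self.cache = data, {}

    def prefetch(self, indices):
        self.cache = {i: self.data[i] for i in indices}

class PlainDataset:
    def __init__(self, data):
        self.data = data  # no prefetch support at all

datasets = [CachedDataset([10, 11]), PlainDataset([20, 21])]

# Old check: all([d.supports_prefetch ...]) raises AttributeError on PlainDataset.
# New check: report support if any child has it...
assert any(getattr(d, 'supports_prefetch', False) for d in datasets)

# ...and have prefetch() skip the children that lack it:
for ds in datasets:
    if getattr(ds, 'supports_prefetch', False):
        ds.prefetch([0, 1])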

@@ -21,7 +21,7 @@ def infer_language_pair(path):
     return src, dst


-def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
+def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
     """Convert a list of 1d tensors into a padded 2d tensor."""
     size = max(v.size(0) for v in values)
     res = values[0].new(len(values), size).fill_(pad_idx)
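
The new defaults make collate_tokens callable with just values and pad_idx. A behavioral sketch of the padding logic consistent with this signature (the move_eos_to_beginning handling is omitted):

import torch

def collate_tokens_sketch(values, pad_idx, eos_idx=None, left_pad=False):
    size = max(v.size(0) for v in values)
    res = values[0].new(len(values), size).fill_(pad_idx)
    for i, v in enumerate(values):
        if left_pad:
            res[i][size - len(v):] = v   # pad on the left
        else:
            res[i][:len(v)] = v          # pad on the right (new default)
    return res

batch = collate_tokens_sketch([torch.tensor([1, 2, 3]), torch.tensor([4])], pad_idx=0)
# -> tensor([[1, 2, 3],
#            [4, 0, 0]])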

@@ -21,7 +21,7 @@ class FairseqDataset(torch.utils.data.Dataset):
         """Merge a list of samples to form a mini-batch.

         Args:
-            samples (List[int]): sample indices to collate
+            samples (List[dict]): samples to collate

         Returns:
             dict: a mini-batch suitable for forwarding with a Model
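
A minimal collater consistent with the corrected docstring: it consumes a list of sample dicts and produces a single mini-batch dict. The field names ('id', 'source') and the pad index are illustrative, not fairseq's exact schema:

import torch
from torch.nn.utils.rnn import pad_sequence

def collater(samples):
    return {
        'id': torch.tensor([s['id'] for s in samples]),
        'net_input': {
            'src_tokens': pad_sequence(
                [s['source'] for s in samples],
                batch_first=True, padding_value=1,  # 1 = pad index here
            ),
        },
    }

batch = collater([
    {'id': 0, 'source': torch.tensor([5, 6, 7])},
    {'id': 1, 'source': torch.tensor([8])},
])
# batch['net_input']['src_tokens'] -> tensor([[5, 6, 7], [8, 1, 1]])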

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+
 import os
 import shutil
 import struct
@@ -11,6 +12,8 @@ import struct
 import numpy as np
 import torch

+from . import FairseqDataset
+

 def make_builder(out_file, impl):
     if impl == 'mmap':

@@ -78,7 +81,7 @@ def data_file_path(prefix_path):
     return prefix_path + '.bin'


-class IndexedDataset(torch.utils.data.Dataset):
+class IndexedDataset(FairseqDataset):
     """Loader for TorchNet IndexedDataset"""

     def __init__(self, path, fix_lua_indexing=False):

@@ -99,16 +102,16 @@ class IndexedDataset(torch.utils.data.Dataset):
             assert struct.unpack('<Q', version) == (1,)
             code, self.element_size = struct.unpack('<QQ', f.read(16))
             self.dtype = dtypes[code]
-            self.size, self.s = struct.unpack('<QQ', f.read(16))
-            self.dim_offsets = read_longs(f, self.size + 1)
-            self.data_offsets = read_longs(f, self.size + 1)
+            self._len, self.s = struct.unpack('<QQ', f.read(16))
+            self.dim_offsets = read_longs(f, self._len + 1)
+            self.data_offsets = read_longs(f, self._len + 1)
             self.sizes = read_longs(f, self.s)

     def read_data(self, path):
         self.data_file = open(data_file_path(path), 'rb', buffering=0)

     def check_index(self, i):
-        if i < 0 or i >= self.size:
+        if i < 0 or i >= self._len:
             raise IndexError('index out of range')

     def __del__(self):

@@ -129,7 +132,13 @@ class IndexedDataset(torch.utils.data.Dataset):
         return item

     def __len__(self):
-        return self.size
+        return self._len
+
+    def num_tokens(self, index):
+        return self.sizes[index]
+
+    def size(self, index):
+        return self.sizes[index]

     @staticmethod
     def exists(path):
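
The size-to-_len rename is not just lint: this same hunk adds a size(index) method (part of the FairseqDataset interface), and an instance attribute named size would shadow it. A minimal illustration of the clash:

class Broken:
    def __init__(self):
        self.size = 100      # instance attribute...

    def size(self, index):   # ...shadows this method on every instance
        return 1

b = Broken()
print(b.size)   # 100 -- the attribute wins
# b.size(0)     # TypeError: 'int' object is not callable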

@@ -189,7 +198,7 @@ class IndexedCachedDataset(IndexedDataset):
         return item


-class IndexedRawTextDataset(torch.utils.data.Dataset):
+class IndexedRawTextDataset(FairseqDataset):
     """Takes a text file as input and binarizes it in memory at instantiation.
     Original lines are also kept in memory"""

@@ -232,6 +241,12 @@ class IndexedRawTextDataset(torch.utils.data.Dataset):
     def __len__(self):
         return self.size

+    def num_tokens(self, index):
+        return self.sizes[index]
+
+    def size(self, index):
+        return self.sizes[index]
+
     @staticmethod
     def exists(path):
         return os.path.exists(path)
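
num_tokens(index) and size(index) let batching code group examples by length without loading them. A hedged sketch of token-capped batching in that spirit (fairseq's real batching helper has more options):

def batches_by_token_count(dataset, indices, max_tokens):
    """Yield lists of indices whose padded size stays under max_tokens."""
    batch, batch_max = [], 0
    for idx in indices:
        new_max = max(batch_max, dataset.num_tokens(idx))
        # cost of a padded batch ~= longest item * number of items
        if batch and new_max * (len(batch) + 1) > max_tokens:
            yield batch
            batch, new_max = [], dataset.num_tokens(idx)
        batch.append(idx)
        batch_max = new_max
    if batch:
        yield batch

class ToyDataset:
    sizes = [3, 5, 2, 8]
    def num_tokens(self, index):
        return self.sizes[index]

print(list(batches_by_token_count(ToyDataset(), range(4), max_tokens=10)))
# -> [[0, 1], [2], [3]]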

@@ -68,7 +68,7 @@ class BaseFairseqModel(nn.Module):
         this additionally "upgrades" *state_dicts* from old checkpoints.
         """
         self.upgrade_state_dict(state_dict)
-        super().load_state_dict(state_dict, strict)
+        return super().load_state_dict(state_dict, strict)

     def upgrade_state_dict(self, state_dict):
         """Upgrade old state dicts to work with newer code."""
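
Returning the result of super().load_state_dict() matters when strict=False: PyTorch (1.1+) reports which keys were missing or unexpected, and callers can now inspect that instead of having the information dropped. For example:

import torch.nn as nn

model = nn.Linear(4, 4)
result = model.load_state_dict({'weight': model.weight.data}, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []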

@@ -330,6 +330,8 @@ def add_checkpoint_args(parser):
                        help='if set, does not load lr scheduler state from the checkpoint')
     group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
                        help='a dictionary used to override optimizer args when loading a checkpoint')
+    group.add_argument('--reset-meters', action='store_true',
+                       help='if set, does not load meters from the checkpoint')
     group.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='save a checkpoint every N epochs')
     group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
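
Once parsed, the new flag behaves like the other --reset-* options. A simplified parser demonstrating the store_true semantics (not fairseq's actual options plumbing):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('checkpoint')
group.add_argument('--reset-meters', action='store_true',
                   help='if set, does not load meters from the checkpoint')

assert parser.parse_args([]).reset_meters is False   # omitted -> False
assert parser.parse_args(['--reset-meters']).reset_meters is True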

@@ -129,7 +129,14 @@ class Trainer(object):
             self._optim_history, extra_state,
         )

-    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
+    def load_checkpoint(
+        self,
+        filename,
+        reset_optimizer=False,
+        reset_lr_scheduler=False,
+        optimizer_overrides=None,
+        reset_meters=False,
+    ):
         """Load all training state from a checkpoint file."""
         extra_state, self._optim_history, last_optim_state = None, [], None
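
End to end, the flag flows from the command line into Trainer.load_checkpoint; the first hunk of this commit shows the train.py side. A stub illustrating only the keyword plumbing (StubTrainer is a stand-in, not fairseq's Trainer):

class StubTrainer:
    def load_checkpoint(self, filename, reset_optimizer=False,
                        reset_lr_scheduler=False, optimizer_overrides=None,
                        reset_meters=False):
        # fairseq's real Trainer restores model/optimizer state here; when
        # reset_meters is False it also restores the saved training meters.
        return {'loaded': filename, 'meters_restored': not reset_meters}

print(StubTrainer().load_checkpoint('checkpoint_last.pt', reset_meters=True))
# -> {'loaded': 'checkpoint_last.pt', 'meters_restored': False}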