Commit ba989ed1 authored by Myle Ott, committed by Facebook GitHub Bot

Small features + lint

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/588

Differential Revision: D15389638

Pulled By: myleott

fbshipit-source-id: 4632ce22d51dc2c74d250bae999630095d849701
parent 3bfbb49b
--- a/train.py
+++ b/train.py
@@ -103,6 +103,7 @@ def load_checkpoint(args, trainer):
             args.reset_optimizer,
             args.reset_lr_scheduler,
             eval(args.optimizer_overrides),
+            reset_meters=args.reset_meters,
         )
         if extra_state is not None and 'best' in extra_state and not args.reset_optimizer:
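Note that `args.optimizer_overrides` arrives as a string (the flag defaults to `"{}"`, see the options.py hunk below) and is turned into a dict with `eval`. A minimal sketch of the same parse using `ast.literal_eval` as a safer stand-in (this is not what fairseq itself calls):

```python
import ast

# --optimizer-overrides is passed on the command line as a dict literal,
# e.g. "{'lr': 0.0001}"; literal_eval parses it without executing code.
overrides = ast.literal_eval("{'lr': 0.0001}")
print(overrides['lr'])  # 0.0001
```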
--- a/fairseq/data/concat_dataset.py
+++ b/fairseq/data/concat_dataset.py
@@ -50,11 +50,12 @@ class ConcatDataset(FairseqDataset):
     @property
     def supports_prefetch(self):
-        return all([d.supports_prefetch for d in self.datasets])
+        return any(getattr(d, 'supports_prefetch', False) for d in self.datasets)

     def prefetch(self, indices):
         frm = 0
         for to, ds in zip(self.cumulative_sizes, self.datasets):
             real_size = len(ds)
-            ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
+            if getattr(ds, 'supports_prefetch', False):
+                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
             frm = to
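The `(i - frm) % real_size` arithmetic maps a global index into the index space of the constituent dataset that owns it; the modulo also wraps indices of upsampled datasets, whose cumulative size exceeds their real length. A minimal sketch of that mapping, using plain lists in place of real datasets (`to_local` and the sample ratios are made-up illustrations):

```python
import itertools

# Hypothetical stand-ins for the real datasets: dataset B is "upsampled" 2x,
# so it contributes 6 cumulative entries but only has 3 real items.
datasets = [['a0', 'a1'], ['b0', 'b1', 'b2']]
sample_ratios = [1, 2]
real_sizes = [len(d) for d in datasets]
cumulative_sizes = list(itertools.accumulate(
    len(d) * r for d, r in zip(datasets, sample_ratios)
))  # [2, 8]

def to_local(i):
    """Map a global ConcatDataset index to (dataset_idx, local_idx)."""
    frm = 0
    for ds_idx, to in enumerate(cumulative_sizes):
        if frm <= i < to:
            # Modulo wraps upsampled indices back into the real range,
            # mirroring (i - frm) % real_size in ConcatDataset.prefetch.
            return ds_idx, (i - frm) % real_sizes[ds_idx]
        frm = to
    raise IndexError(i)

assert to_local(1) == (0, 1)   # falls in dataset A
assert to_local(6) == (1, 1)   # upsampled copy of b1
```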
--- a/fairseq/data/data_utils.py
+++ b/fairseq/data/data_utils.py
@@ -21,7 +21,7 @@ def infer_language_pair(path):
     return src, dst


-def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
+def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
     """Convert a list of 1d tensors into a padded 2d tensor."""
     size = max(v.size(0) for v in values)
     res = values[0].new(len(values), size).fill_(pad_idx)
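The rest of `collate_tokens` is outside this hunk; below is a self-contained sketch of the padding behaviour the visible lines set up, with the eos and `move_eos_to_beginning` handling omitted (`pad_2d` is an illustrative name, not fairseq's):

```python
import torch

def pad_2d(values, pad_idx, left_pad=False):
    # Simplified stand-in for fairseq's collate_tokens: pads a list of 1d
    # tensors to a common length. The real function also handles eos tokens.
    size = max(v.size(0) for v in values)
    res = values[0].new(len(values), size).fill_(pad_idx)
    for i, v in enumerate(values):
        # Right-align the tokens when left-padding, else left-align them.
        dst = res[i][size - len(v):] if left_pad else res[i][:len(v)]
        dst.copy_(v)
    return res

batch = pad_2d([torch.tensor([4, 5, 6]), torch.tensor([7])], pad_idx=1)
# tensor([[4, 5, 6],
#         [7, 1, 1]])
```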
--- a/fairseq/data/fairseq_dataset.py
+++ b/fairseq/data/fairseq_dataset.py
@@ -21,7 +21,7 @@ class FairseqDataset(torch.utils.data.Dataset):
         """Merge a list of samples to form a mini-batch.

         Args:
-            samples (List[int]): sample indices to collate
+            samples (List[dict]): samples to collate

         Returns:
             dict: a mini-batch suitable for forwarding with a Model
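The corrected docstring matters for subclass authors: `collater` receives the already-fetched samples, not indices. A hedged sketch of a minimal subclass (the `id`/`source` keys are illustrative, not fairseq's actual batch schema):

```python
import torch
from fairseq.data import FairseqDataset  # assumes fairseq is installed

class ToyDataset(FairseqDataset):
    def __init__(self, tensors):
        self.tensors = tensors  # assumed to all share one shape

    def __getitem__(self, index):
        return {'id': index, 'source': self.tensors[index]}

    def __len__(self):
        return len(self.tensors)

    def collater(self, samples):
        # samples is a List[dict], one dict per __getitem__ call,
        # not a list of indices.
        return {
            'id': torch.tensor([s['id'] for s in samples]),
            'source': torch.stack([s['source'] for s in samples]),
        }
```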
--- a/fairseq/data/indexed_dataset.py
+++ b/fairseq/data/indexed_dataset.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+
 import os
 import shutil
 import struct
@@ -11,6 +12,8 @@ import struct
 import numpy as np
 import torch

+from . import FairseqDataset
+

 def make_builder(out_file, impl):
     if impl == 'mmap':
@@ -78,7 +81,7 @@ def data_file_path(prefix_path):
     return prefix_path + '.bin'


-class IndexedDataset(torch.utils.data.Dataset):
+class IndexedDataset(FairseqDataset):
     """Loader for TorchNet IndexedDataset"""

     def __init__(self, path, fix_lua_indexing=False):
@@ -99,16 +102,16 @@ class IndexedDataset(torch.utils.data.Dataset):
             assert struct.unpack('<Q', version) == (1,)
             code, self.element_size = struct.unpack('<QQ', f.read(16))
             self.dtype = dtypes[code]
-            self.size, self.s = struct.unpack('<QQ', f.read(16))
-            self.dim_offsets = read_longs(f, self.size + 1)
-            self.data_offsets = read_longs(f, self.size + 1)
+            self._len, self.s = struct.unpack('<QQ', f.read(16))
+            self.dim_offsets = read_longs(f, self._len + 1)
+            self.data_offsets = read_longs(f, self._len + 1)
             self.sizes = read_longs(f, self.s)

     def read_data(self, path):
         self.data_file = open(data_file_path(path), 'rb', buffering=0)

     def check_index(self, i):
-        if i < 0 or i >= self.size:
+        if i < 0 or i >= self._len:
             raise IndexError('index out of range')

     def __del__(self):
@@ -129,7 +132,13 @@ class IndexedDataset(torch.utils.data.Dataset):
         return item

     def __len__(self):
-        return self.size
+        return self._len
+
+    def num_tokens(self, index):
+        return self.sizes[index]
+
+    def size(self, index):
+        return self.sizes[index]

     @staticmethod
     def exists(path):
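The `self.size` → `self._len` rename is likely more than lint: the class now inherits from `FairseqDataset`, which exposes a `size(index)` method, and an instance attribute named `size` would shadow it. A toy illustration of the clash being avoided (not fairseq code):

```python
class Base:
    # Stand-in for FairseqDataset's size(index) API.
    def size(self, index):
        return index

class Broken(Base):
    def __init__(self):
        # An instance attribute named `size` shadows the inherited method.
        self.size = 10

try:
    Broken().size(0)
except TypeError as e:
    print(e)  # 'int' object is not callable
```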
@@ -189,7 +198,7 @@ class IndexedCachedDataset(IndexedDataset):
         return item


-class IndexedRawTextDataset(torch.utils.data.Dataset):
+class IndexedRawTextDataset(FairseqDataset):
     """Takes a text file as input and binarizes it in memory at instantiation.
     Original lines are also kept in memory"""
@@ -232,6 +241,12 @@ class IndexedRawTextDataset(torch.utils.data.Dataset):
     def __len__(self):
         return self.size

+    def num_tokens(self, index):
+        return self.sizes[index]
+
+    def size(self, index):
+        return self.sizes[index]
+
     @staticmethod
     def exists(path):
         return os.path.exists(path)
--- a/fairseq/models/fairseq_model.py
+++ b/fairseq/models/fairseq_model.py
@@ -68,7 +68,7 @@ class BaseFairseqModel(nn.Module):
         this additionally "upgrades" *state_dicts* from old checkpoints.
         """
         self.upgrade_state_dict(state_dict)
-        super().load_state_dict(state_dict, strict)
+        return super().load_state_dict(state_dict, strict)

     def upgrade_state_dict(self, state_dict):
         """Upgrade old state dicts to work with newer code."""
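Propagating the return value is useful because `nn.Module.load_state_dict` reports which keys were missing or unexpected; dropping it hides that from callers. A standalone demonstration with plain PyTorch:

```python
import torch.nn as nn

model = nn.Linear(4, 2)
# Load a state dict that is missing the bias; strict=False tolerates it,
# but the return value records what went wrong.
result = model.load_state_dict({'weight': model.weight.data}, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # []
```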
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -330,6 +330,8 @@ def add_checkpoint_args(parser):
                        help='if set, does not load lr scheduler state from the checkpoint')
     group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
                        help='a dictionary used to override optimizer args when loading a checkpoint')
+    group.add_argument('--reset-meters', action='store_true',
+                       help='if set, does not load meters from the checkpoint')
     group.add_argument('--save-interval', type=int, default=1, metavar='N',
                        help='save a checkpoint every N epochs')
     group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
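Since `--reset-meters` uses `action='store_true'`, it defaults to `False`, so existing training commands keep their behaviour. The same pattern in a bare argparse parser (not fairseq's full CLI):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--reset-meters', action='store_true',
                    help='if set, does not load meters from the checkpoint')

print(parser.parse_args([]).reset_meters)                  # False (default)
print(parser.parse_args(['--reset-meters']).reset_meters)  # True
```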
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -129,7 +129,14 @@ class Trainer(object):
             self._optim_history, extra_state,
         )

-    def load_checkpoint(self, filename, reset_optimizer=False, reset_lr_scheduler=False, optimizer_overrides=None):
+    def load_checkpoint(
+        self,
+        filename,
+        reset_optimizer=False,
+        reset_lr_scheduler=False,
+        optimizer_overrides=None,
+        reset_meters=False,
+    ):
         """Load all training state from a checkpoint file."""
         extra_state, self._optim_history, last_optim_state = None, [], None