Commit bd4db8fb authored by Myle Ott

Misc changes for pytorch-translate

parent c6fe9fc5
@@ -106,7 +106,7 @@ class Dictionary(object):
         multiple of 8, which is important on some hardware (e.g., Nvidia
         Tensor Cores).
         """
-        if nwords == -1:
+        if nwords <= 0:
             nwords = len(self)
         new_indices = dict(zip(self.symbols[:self.nspecial], range(self.nspecial)))
@@ -133,7 +133,7 @@ class Dictionary(object):
                 i += 1
                 threshold_nwords += 1

-        assert min(new_count[self.nspecial:]) >= threshold
+        assert len(new_count) == self.nspecial or min(new_count[self.nspecial:]) >= threshold
         assert len(new_symbols) % padding_factor == 0
         assert len(new_symbols) == len(new_indices)
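
For context, a minimal sketch (illustrative only, not the code under review) of what the padding_factor logic guards: the finalized vocabulary is grown with filler symbols until its size is a multiple of padding_factor, which is exactly the invariant the len(new_symbols) % padding_factor == 0 assert checks, and the relaxed nwords <= 0 test lets callers pass either 0 or -1 to mean "keep every word".

# Illustrative sketch only: pad a vocabulary with dummy symbols until its size
# is a multiple of padding_factor (8 keeps sizes friendly to Tensor Cores).
def pad_vocab(symbols, padding_factor=8):
    symbols = list(symbols)
    i = 0
    while len(symbols) % padding_factor != 0:
        symbols.append('madeupword{:04d}'.format(i))  # filler token, never emitted
        i += 1
    return symbols

padded = pad_vocab(['<pad>', '</s>', '<unk>', 'hello', 'world'])
assert len(padded) % 8 == 0
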
@@ -187,12 +187,12 @@ class Dictionary(object):
             d.count.append(count)
         return d

-    def save(self, f, threshold=3, nwords=-1):
+    def save(self, f):
         """Stores dictionary into a text file"""
         if isinstance(f, str):
             os.makedirs(os.path.dirname(f), exist_ok=True)
             with open(f, 'w', encoding='utf-8') as fd:
-                return self.save(fd, threshold, nwords)
+                return self.save(fd)
         for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
             print('{} {}'.format(symbol, count), file=f)
...
@@ -52,8 +52,9 @@ def data_file_path(prefix_path):
 class IndexedDataset(torch.utils.data.Dataset):
     """Loader for TorchNet IndexedDataset"""

-    def __init__(self, path):
+    def __init__(self, path, fix_lua_indexing=False):
         super().__init__()
+        self.fix_lua_indexing = fix_lua_indexing
         with open(index_file_path(path), 'rb') as f:
             magic = f.read(8)
             assert magic == b'TNTIDX\x00\x00'
@@ -83,7 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset):
         a = np.empty(tensor_size, dtype=self.dtype)
         self.data_file.seek(self.data_offsets[i] * self.element_size)
         self.data_file.readinto(a)
-        return torch.from_numpy(a).long() - 1  # subtract 1 for 0-based indexing
+        item = torch.from_numpy(a).long()
+        if self.fix_lua_indexing:
+            item -= 1  # subtract 1 for 0-based indexing
+        return item

     def __len__(self):
         return self.size
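
The new fix_lua_indexing flag exists because binarized data produced by the original Lua/TorchNet tooling stores token ids starting at 1, while fairseq dictionaries are 0-based; previously the loader subtracted 1 unconditionally, and now only callers that know their data is 1-indexed opt in. A tiny illustration with made-up ids:

import torch

# Illustration only (ids are invented): what fix_lua_indexing=True applies.
lua_indexed = torch.tensor([1, 5, 12], dtype=torch.long)  # 1-based ids from the Lua pipeline
zero_indexed = lua_indexed - 1                            # shifted to 0-based dictionary ids
print(zero_indexed.tolist())                              # [0, 4, 11]
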
@@ -104,6 +108,7 @@ class IndexedInMemoryDataset(IndexedDataset):
         self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
         self.data_file.readinto(self.buffer)
         self.data_file.close()
-        self.buffer -= 1  # subtract 1 for 0-based indexing
+        if self.fix_lua_indexing:
+            self.buffer -= 1  # subtract 1 for 0-based indexing

     def __del__(self):
...
@@ -73,7 +73,7 @@ class FP16Trainer(Trainer):
         self.fp32_params.grad = self.fp32_params.data.new(total_param_size)

         # create optimizer using the copied FP32 params
-        self.optimizer = optim.build_optimizer(self.args, [self.fp32_params])
+        self._optimizer = optim.build_optimizer(self.args, [self.fp32_params])
         self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

     def save_checkpoint(self, filename, extra_state):
...
@@ -15,6 +15,9 @@ class FixedSchedule(FairseqLRScheduler):
     def __init__(self, args, optimizer):
         super().__init__(args, optimizer)

+        # set defaults
+        args.warmup_updates = getattr(args, 'warmup_updates', 0)
+
         self.lr = args.lr[0]
         if args.warmup_updates > 0:
             self.warmup_factor = 1. / args.warmup_updates
...
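
Defaulting warmup_updates via getattr presumably keeps the scheduler working when it is constructed with an args namespace that never went through fairseq's argument parser (as in pytorch-translate). A minimal illustration with a hypothetical namespace:

from argparse import Namespace

args = Namespace(lr=[0.25])  # hypothetical args without a warmup_updates attribute
args.warmup_updates = getattr(args, 'warmup_updates', 0)  # falls back to 0 instead of raising
assert args.warmup_updates == 0
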
@@ -50,7 +50,7 @@ class LanguageModelingTask(FairseqTask):
             ds = IndexedRawTextDataset(path, self.dictionary)
             tokens = ds.tokens_list
         elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
-            ds = IndexedInMemoryDataset(path)
+            ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
             tokens = ds.buffer
         else:
             raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
...
@@ -89,7 +89,7 @@ class TranslationTask(FairseqTask):
             if self.args.raw_text:
                 return IndexedRawTextDataset(path, dictionary)
             elif IndexedInMemoryDataset.exists(path):
-                return IndexedInMemoryDataset(path)
+                return IndexedInMemoryDataset(path, fix_lua_indexing=True)
             return None

         src_dataset = indexed_dataset(prefix + src, self.src_dict)
...
@@ -40,8 +40,6 @@ class Trainer(object):
         self.model = model.cuda()
         self.criterion = criterion.cuda()

-        self.optimizer = None
-
         # initialize meters
         self.meters = OrderedDict()
         self.meters['train_loss'] = AverageMeter()
@@ -61,10 +59,17 @@ class Trainer(object):
         self._flat_grads = None
         self._num_updates = 0
         self._optim_history = None
+        self._optimizer = None
+
+    @property
+    def optimizer(self):
+        if self._optimizer is None:
+            self._build_optimizer()
+        return self._optimizer

     def _build_optimizer(self):
-        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
-        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
+        self._optimizer = optim.build_optimizer(self.args, self.model.parameters())
+        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self._optimizer)

     def save_checkpoint(self, filename, extra_state):
         """Save all training state in a checkpoint file."""
@@ -93,7 +98,7 @@ class Trainer(object):
             self._num_updates = last_optim['num_updates']

-        if 'train_meters' in extra_state:
+        if extra_state is not None and 'train_meters' in extra_state:
             self.meters = extra_state['train_meters']
             del extra_state['train_meters']
@@ -101,11 +106,6 @@ class Trainer(object):
     def train_step(self, sample, update_params=True):
         """Do forward, backward and parameter update."""
-        if self.optimizer is None:
-            # initialize optimizer and LR scheduler if hasn't been loaded from the checkpoint
-            self._build_optimizer()
-
         # Set seed based on args.seed and the update number so that we get
         # reproducible results when resuming from checkpoints
         seed = self.args.seed + self.get_num_updates()
...
@@ -126,7 +126,7 @@ def _upgrade_state_dict(state):
     if 'train_iterator' not in state['extra_state']:
         state['extra_state']['train_iterator'] = {
             'epoch': state['extra_state']['epoch'],
-            'iterations_in_epoch': 0,
+            'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
         }

     return state
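
A hypothetical before/after for the checkpoint upgrade: an old checkpoint that recorded its position as batch_offset now seeds iterations_in_epoch instead of resetting it to 0, so resuming continues mid-epoch.

# Made-up old-style checkpoint state, illustrating the upgrade path above.
state = {'extra_state': {'epoch': 3, 'batch_offset': 1500}}
if 'train_iterator' not in state['extra_state']:
    state['extra_state']['train_iterator'] = {
        'epoch': state['extra_state']['epoch'],
        'iterations_in_epoch': state['extra_state'].get('batch_offset', 0),
    }
print(state['extra_state']['train_iterator'])  # {'epoch': 3, 'iterations_in_epoch': 1500}
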