Commit 47fbc491 authored by myleott's avatar myleott
Browse files

fbshipit-source-id: 682b375c6e7535f12faaf9ca32811051f9e874da

parent cfeb2163
......@@ -134,8 +134,7 @@ class IndexedDataset(torch.utils.data.Dataset):
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and
os.path.exists(data_file_path(path))
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
@property
......@@ -432,8 +431,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
@staticmethod
def exists(path):
return (
os.path.exists(index_file_path(path)) and
os.path.exists(data_file_path(path))
os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
)
......
......@@ -43,7 +43,7 @@ class RoundRobinZipDatasets(FairseqDataset):
def _map_index(self, key, index):
assert self._ordered_indices is not None, \
'Must call RoundRobinZipDatasets.ordered_indices() first'
'Must call RoundRobinZipDatasets.ordered_indices() first'
return self._ordered_indices[key][index % len(self.datasets[key])]
def __getitem__(self, index):
......
......@@ -40,6 +40,7 @@ class TokenBlockDataset(FairseqDataset):
self.slice_indices = []
assert len(dataset) == len(sizes)
assert len(dataset) > 0
sizes = np.array(sizes, dtype=int)
if break_mode is None or break_mode == 'none':
total_size = sum(sizes)
......@@ -71,7 +72,8 @@ class TokenBlockDataset(FairseqDataset):
sizes = torch.tensor(sizes)
cumsum = torch.cumsum(sizes, dim=0)
self.slice_indices[0] = [0, sizes[0]]
self.slice_indices[1:] = cumsum.unfold(0, 2, 1)
if len(cumsum) > 1:
self.slice_indices[1:] = cumsum.unfold(0, 2, 1)
else:
raise ValueError('Invalid break_mode: ' + break_mode)
......
......@@ -252,6 +252,11 @@ class tensorboard_log_wrapper(progress_bar):
self._log_to_tensorboard(stats, tag, step)
self.wrapped_bar.print(stats, tag=tag, step=step)
def __exit__(self, *exc):
for writer in getattr(self, '_writers', {}).values():
writer.close()
return False
def _log_to_tensorboard(self, stats, tag='', step=None):
writer = self._writer(tag)
if writer is None:
......
......@@ -181,6 +181,5 @@ class CrossLingualLMTask(FairseqTask):
dataset_map, default_key=self.default_key
)
print('| {} {} {} examples'.format(
self.args.data.split(':')[epoch], split, len(self.datasets[split])
)
self.args.data.split(':')[epoch], split, len(self.datasets[split]))
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment