"git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "4a60b45d4c4c80fa934d33e51edfcd29f9795470"
Commit d9c79133 authored by Myle Ott, committed by Facebook Github Bot

Add 'doc' break mode to TokenBlockDataset

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/679

Test Plan: https://our.intern.facebook.com/intern/chronos/jobinstance/?jobinstanceid=5191319216&smc=chronos_gp_admin_client&log_type=stdout&offset=0&pretty_logs=false

Differential Revision: D15961008

Pulled By: myleott

fbshipit-source-id: cf214de96665b33887ef64cfcb45a51f81002ed1
parent efb43450
fairseq/data/token_block_dataset.py

@@ -26,12 +26,20 @@ class TokenBlockDataset(FairseqDataset):
             - 'complete': break tokens into blocks (up to block_size) such that
               blocks contains complete sentences, although block_size may be
               exceeded if some sentences exceed block_size
+            - 'complete_doc': similar to 'complete' mode, but do not
+              cross document boundaries
             - 'eos': each block contains one sentence (block_size is ignored)
         include_targets (bool, optional): return next tokens as targets
             (default: False).
+        document_sep_len (int, optional): document separator size (required for
+            'complete_doc' break mode). Typically 1 if the sentences have eos
+            and 0 otherwise.
     """

-    def __init__(self, dataset, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
+    def __init__(
+        self, dataset, sizes, block_size, pad, eos, break_mode=None,
+        include_targets=False, document_sep_len=1,
+    ):
         super().__init__()
         self.dataset = dataset
         self.pad = pad
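For context, a minimal sketch of how the new argument might be used (not part of the commit, and hedged accordingly): it assumes fairseq is installed and uses a toy list of eos-terminated sentence tensors in place of a real indexed dataset; all token values are illustrative.

import torch
from fairseq.data import TokenBlockDataset

eos, pad = 2, 1
# Two sentences in doc 1, then an empty (eos-only) sentence as the
# document separator, then one sentence in doc 2.
sentences = [
    torch.tensor([4, 5, eos]),
    torch.tensor([6, eos]),
    torch.tensor([eos]),          # document separator
    torch.tensor([7, 8, 9, eos]),
]
sizes = [len(s) for s in sentences]

dataset = TokenBlockDataset(
    sentences, sizes, block_size=512, pad=pad, eos=eos,
    break_mode='complete_doc',
    document_sep_len=1,  # separator sentences are a lone eos token
)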
@@ -66,6 +74,27 @@ class TokenBlockDataset(FairseqDataset):
                     curr_size = 0
             if curr_size > 0:
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
+        elif break_mode == 'complete_doc':
+            tok_idx = 0
+            sz_idx = 0
+            curr_size = 0
+            while sz_idx < len(sizes):
+                if (
+                    (curr_size + sizes[sz_idx] <= block_size or curr_size == 0)
+                    # an empty sentence indicates end-of-document:
+                    and sizes[sz_idx] != document_sep_len
+                ):
+                    curr_size += sizes[sz_idx]
+                    sz_idx += 1
+                else:
+                    self.slice_indices.append((tok_idx, tok_idx + curr_size))
+                    tok_idx += curr_size
+                    curr_size = 0
+                    if sizes[sz_idx] == document_sep_len:
+                        tok_idx += sizes[sz_idx]
+                        sz_idx += 1
+            if curr_size > 0:
+                self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
             self.slice_indices = np.empty((len(sizes), 2), dtype=int)
             if not torch.is_tensor(sizes):
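To make the boundary behavior of the loop above easy to check outside fairseq, here is a standalone re-implementation with the same logic (an illustration, not code from the commit; complete_doc_slices is a hypothetical helper name). Note that any sentence whose size equals document_sep_len is treated as a separator.

def complete_doc_slices(sizes, block_size, document_sep_len=1):
    slices = []
    tok_idx, sz_idx, curr_size = 0, 0, 0
    while sz_idx < len(sizes):
        if (
            (curr_size + sizes[sz_idx] <= block_size or curr_size == 0)
            and sizes[sz_idx] != document_sep_len
        ):
            # sentence fits in the current block and is not a separator
            curr_size += sizes[sz_idx]
            sz_idx += 1
        else:
            # close the current block; skip separators entirely
            slices.append((tok_idx, tok_idx + curr_size))
            tok_idx += curr_size
            curr_size = 0
            if sizes[sz_idx] == document_sep_len:
                tok_idx += sizes[sz_idx]
                sz_idx += 1
    if curr_size > 0:
        slices.append((tok_idx, tok_idx + curr_size))
    return slices

# Sentences of 3 and 2 tokens, a size-1 document separator, then 4 tokens.
# The 2-token sentence does not fit after the 3-token one (block_size=4),
# and the last sentence starts a fresh block after the separator.
print(complete_doc_slices([3, 2, 1, 4], block_size=4))  # [(0, 3), (3, 5), (6, 10)]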
@@ -92,27 +121,21 @@ class TokenBlockDataset(FairseqDataset):
                 1,
             )
         else:
+            ds = DatasetSearcher(sizes)
             self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
-            ds_idx, ds_remaining = -1, 0
             for i, (s, e) in enumerate(self.slice_indices):
-                to_consume = e - s
-                if ds_remaining == 0:
-                    ds_idx += 1
-                    ds_remaining = sizes[ds_idx]
-                start_ds_idx = ds_idx
-                start_offset = sizes[ds_idx] - ds_remaining
-                while to_consume > ds_remaining:
-                    to_consume -= ds_remaining
-                    ds_idx += 1
-                    ds_remaining = sizes[ds_idx]
-                ds_remaining -= to_consume
+                ds.seek(s)
+                start_ds_idx = ds.current_index
+                start_offset = ds.current_offset
+                if e <= s:
+                    continue
+                ds.seek(e - 1)
+                end_ds_idx = ds.current_index
                 self.block_to_dataset_index[i] = (
                     start_ds_idx,  # starting index in dataset
                     start_offset,  # starting offset within starting index
-                    ds_idx,  # ending index in dataset
+                    end_ds_idx,  # ending index in dataset
                 )
-            assert ds_remaining == 0
-            assert ds_idx == len(self.dataset) - 1

     def __getitem__(self, index):
         start_ds_idx, start_offset, end_ds_idx = self.block_to_dataset_index[index]
@@ -155,3 +178,34 @@ class TokenBlockDataset(FairseqDataset):
             for start_ds_idx, _, end_ds_idx in [self.block_to_dataset_index[index]]
             for ds_idx in range(start_ds_idx, end_ds_idx + 1)
         })
+
+
+class DatasetSearcher(object):
+    """Helper for mapping "flat" indices to indices and offsets in an
+    underlying dataset."""
+
+    def __init__(self, sizes):
+        self.sizes = sizes
+        self.reset()
+
+    def reset(self):
+        self.current_index = 0  # index in underlying dataset
+        self.current_offset = 0  # offset within current index in underlying dataset
+        self.current_i = 0  # "flat" index
+
+    def seek(self, i):
+        assert i >= 0
+        if i < self.current_i:
+            self.reset()
+        if i > self.current_i:
+            to_consume = i - self.current_i
+            remaining = self.sizes[self.current_index] - self.current_offset
+            if remaining > to_consume:
+                self.current_offset += to_consume
+                self.current_i += to_consume
+            else:
+                self.current_i += remaining
+                self.current_index += 1
+                self.current_offset = 0
+                self.seek(i)
+        assert self.current_i == i
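A quick sanity check of the seek() semantics (an illustration, not part of the commit): with sentence sizes [3, 2, 4], flat positions 0..2 fall in sentence 0, 3..4 in sentence 1, and 5..8 in sentence 2. The __init__ change above relies on exactly this to turn each (start, end) slice into a (start_ds_idx, start_offset, end_ds_idx) triple.

searcher = DatasetSearcher([3, 2, 4])

searcher.seek(4)  # flat position 4 -> sentence 1, offset 1
print(searcher.current_index, searcher.current_offset)  # 1 1

searcher.seek(8)  # forward seeks advance from the cached position
print(searcher.current_index, searcher.current_offset)  # 2 3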
fairseq/tasks/language_modeling.py

@@ -59,11 +59,12 @@ class LanguageModelingTask(FairseqTask):
         """Add task-specific arguments to the parser."""
         # fmt: off
         parser.add_argument('data', help='path to data directory')
-        parser.add_argument('--sample-break-mode',
-                            choices=['none', 'complete', 'eos'],
+        parser.add_argument('--sample-break-mode', default='none',
+                            choices=['none', 'complete', 'complete_doc', 'eos'],
                             help='If omitted or "none", fills each sample with tokens-per-sample '
                                  'tokens. If set to "complete", splits samples only at the end '
                                  'of sentence, but may include multiple sentences per sample. '
+                                 '"complete_doc" is similar but respects doc boundaries. '
                                  'If set to "eos", includes only one sentence per sample.')
         parser.add_argument('--tokens-per-sample', default=1024, type=int,
                             help='max number of tokens per sample for LM dataset')
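A hypothetical invocation of the new flag (not part of the commit), assuming a preprocessed data-bin directory and the train.py entry point from the fairseq repo of this era:

python train.py data-bin/wikitext-103 \
    --task language_modeling \
    --sample-break-mode complete_doc \
    --tokens-per-sample 1024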