"docs/source/vscode:/vscode.git/clone" did not exist on "d155b38d6ea70fef3dec2e1f678269e713672bb7"
Commit 594202a9 authored by VictorSanh, committed by Victor SANH

lm_seqs_dataset

parent 38084507
@@ -12,30 +12,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Dataloaders to train DistilBERT
""" Dataset to distilled models
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
from typing import List
import math
from itertools import chain
from collections import Counter
import numpy as np
import torch
from torch.utils.data import Dataset
import numpy as np
from utils import logger
class Dataset:
class LmSeqsDataset(Dataset):
"""Custom Dataset wrapping language modeling sequences.
Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.
Input:
------
params: `NameSpace` parameters
data: `List[np.array[int]]`
"""
def __init__(self,
params,
data):
self.params = params
self.tokens_per_batch = params.tokens_per_batch
self.batch_size = params.batch_size
self.shuffle = params.shuffle
self.group_by_size = params.group_by_size
self.token_ids = np.array(data)
self.lengths = np.uint16([len(t) for t in data])
self.lengths = np.array([len(t) for t in data])
self.check()
self.remove_long_sequences()
@@ -43,6 +46,9 @@ class Dataset:
self.check()
self.print_statistics()
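A minimal sketch (not part of this commit) of the inputs the constructor expects; the attribute names mirror the ones referenced in this file, while the concrete token ids and values are illustrative assumptions, and only the fields the updated __init__ and helpers read are included.
from argparse import Namespace
import numpy as np

# `params` only needs the attributes this file reads; values below are illustrative.
params = Namespace(
    max_model_input_size=512,
    mlm=True,
    special_tok_ids={'cls_token': 101, 'sep_token': 102,
                     'pad_token': 0, 'unk_token': 100},
)
# `data` is a list of variable-length numpy arrays of token ids,
# each already framed as [cls_id, ..., sep_id].
data = [np.array([101] + list(range(1000, 1020)) + [102]),
        np.array([101] + list(range(2000, 2030)) + [102])]

dataset = LmSeqsDataset(params=params, data=data)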
def __getitem__(self, index):
return (self.token_ids[index], self.lengths[index])
def __len__(self):
return len(self.lengths)
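Because the class now subclasses torch.utils.data.Dataset and exposes __getitem__/__len__, it can be consumed by a standard DataLoader, with batch_sequences doing the padding as collate_fn. A sketch only: the actual wiring in train.py is not shown in this diff, so the sampler and batch size below are assumptions, and `dataset` comes from the sketch above.
from torch.utils.data import DataLoader, RandomSampler

sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset=dataset,
                        sampler=sampler,
                        batch_size=8,
                        collate_fn=dataset.batch_sequences)
for token_ids, lengths in dataloader:
    # token_ids: (bs, max_seq_len_ of the batch), lengths: (bs)
    pass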
@@ -51,12 +57,14 @@ class Dataset:
Some sanity checks
"""
assert len(self.token_ids) == len(self.lengths)
assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))
def remove_long_sequences(self):
"""
Sequences that are too long are split into chunks of max_position_embeddings.
Sequences that are too long are split into chunks of max_model_input_size.
"""
indices = self.lengths >= self.params.max_position_embeddings
max_len = self.params.max_model_input_size
indices = self.lengths > max_len
logger.info(f'Splitting {sum(indices)} too long sequences.')
def divide_chunks(l, n):
@@ -64,10 +72,13 @@ class Dataset:
new_tok_ids = []
new_lengths = []
cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
max_len = self.params.max_position_embeddings
if self.params.mlm:
cls_id, sep_id = self.params.special_tok_ids['cls_token'], self.params.special_tok_ids['sep_token']
else:
cls_id, sep_id = self.params.special_tok_ids['bos_token'], self.params.special_tok_ids['eos_token']
for seq_, len_ in zip(self.token_ids, self.lengths):
assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
if len_ <= max_len:
new_tok_ids.append(seq_)
new_lengths.append(len_)
@@ -79,6 +90,7 @@ class Dataset:
if sub_s[-1] != sep_id:
sub_s = np.insert(sub_s, len(sub_s), sep_id)
assert len(sub_s) <= max_len
assert (sub_s[0] == cls_id) and (sub_s[-1] == sep_id), sub_s
sub_seqs.append(sub_s)
new_tok_ids.extend(sub_seqs)
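# Worked illustration of the splitting above (assumed values: max_len = 512,
# cls_id = 101, sep_id = 102, and chunks cut to max_len - 2 tokens so there is
# room for the special tokens; not part of the commit):
#   a 700-token sequence [101, t_1, ..., t_698, 102] yields
#     piece 1 = seq_[0:510]   -> starts with 101, 102 appended  -> 511 tokens
#     piece 2 = seq_[510:700] -> 101 prepended, ends with 102   -> 191 tokens
#   every chunk therefore satisfies len(sub_s) <= max_len and keeps the cls/sep framing.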
@@ -113,89 +125,27 @@ class Dataset:
# nb_unknown = sum([(t==unk_idx).sum() for t in self.token_ids])
# logger.info(f'{nb_unknown} unknown tokens (covering {100*nb_unknown/data_len:.2f}% of the data)')
def select_data(self, a: int, b: int):
"""
Select a subset of the data.
"""
n_sequences = len(self)
assert 0 <= a < b <= n_sequences, ValueError(f'`0 <= a < b <= n_sequences` is not met with a={a} and b={b}')
logger.info(f'Selecting sequences from {a} to {b} (excluded).')
self.token_ids = self.token_ids[a:b]
self.lengths = self.lengths[a:b]
self.check()
def split(self):
"""
Distributed training: split the data across the processes.
"""
assert self.params.n_gpu > 1
logger.info('Splitting the data across the processes.')
n_seq = len(self)
n_seq_per_process = n_seq // self.params.world_size
a = n_seq_per_process * self.params.global_rank
b = a + n_seq_per_process
self.select_data(a=a, b=b)
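For concreteness, the sharding arithmetic above with made-up numbers (a sketch, not output from the code):
# 1000 sequences, world_size = 4, global_rank = 2:
#   n_seq_per_process = 1000 // 4 = 250
#   a = 250 * 2 = 500,  b = 500 + 250 = 750
# -> this process keeps sequences [500, 750) via select_data(a=500, b=750).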
def batch_sequences(self,
token_ids: List[List[int]],
lengths: List[int]):
batch):
"""
Pad the sequences and convert them to torch tensors.
"""
token_ids = [t[0] for t in batch]
lengths = [t[1] for t in batch]
assert len(token_ids) == len(lengths)
# Max for paddings
max_seq_len_ = max(lengths)
# Pad token ids
pad_idx = self.params.special_tok_ids['pad_token']
if self.params.mlm:
pad_idx = self.params.special_tok_ids['pad_token']
else:
pad_idx = self.params.special_tok_ids['unk_token']
tk_ = [list(t.astype(int)) + [pad_idx]*(max_seq_len_-len(t)) for t in token_ids]
assert len(tk_) == len(token_ids)
assert all(len(t) == max_seq_len_ for t in tk_)
tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
lg_t = torch.tensor(lengths.astype(int)) # (bs)
tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
lg_t = torch.tensor(lengths) # (bs)
return tk_t, lg_t
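A small illustration of the padding performed above (token ids and pad_idx = 0 are assumed values, not taken from the commit):
# batch = [(array([101, 7, 8,   9, 102]), 5),
#          (array([101, 7, 102]),         3),
#          (array([101, 7, 8, 102]),      4)]
# max_seq_len_ = 5, so batch_sequences returns
#   tk_t = tensor([[101, 7,   8,   9, 102],
#                  [101, 7, 102,   0,   0],
#                  [101, 7,   8, 102,   0]])   # shape (3, 5)
#   lg_t = tensor([5, 3, 4])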
def get_batches_iterator(self,
batches):
"""
Return an iterator over batches.
"""
for sequences_ids in batches:
token_ids, lengths = self.batch_sequences(self.token_ids[sequences_ids],
self.lengths[sequences_ids])
yield (token_ids, lengths)
def get_iterator(self,
seed: int = None):
"""
Return a data iterator.
"""
rng = np.random.RandomState(seed)
n_sequences = len(self)
indices = np.arange(n_sequences)
if self.group_by_size:
indices = indices[np.argsort(self.lengths[indices], kind='mergesort')]
if self.tokens_per_batch == -1:
batches = np.array_split(indices, math.ceil(len(indices) * 1. / self.batch_size))
else:
assert self.tokens_per_batch > 0
batch_ids = np.cumsum(self.lengths[indices]) // self.tokens_per_batch
_, bounds = np.unique(batch_ids, return_index=True)
batches = [indices[bounds[i]:bounds[i + 1]] for i in range(len(bounds) - 1)]
if bounds[-1] < len(indices):
batches.append(indices[bounds[-1]:])
if self.shuffle:
rng.shuffle(batches)
assert n_sequences == sum([len(x) for x in batches])
assert self.lengths[indices].sum() == sum([self.lengths[x].sum() for x in batches])
return self.get_batches_iterator(batches=batches)
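For reference, a sketch of how the tokens_per_batch grouping above behaves on made-up lengths (not part of the commit):
lengths = np.array([300, 400, 350, 200, 500])
tokens_per_batch = 1000
batch_ids = np.cumsum(lengths) // tokens_per_batch   # cumsum = [300, 700, 1050, 1250, 1750] -> [0, 0, 1, 1, 1]
_, bounds = np.unique(batch_ids, return_index=True)  # bounds = [0, 2]
# -> the first batch holds sequences 0-1 (~700 tokens); the trailing batch of sequences 2-4
#    is appended by the `if bounds[-1] < len(indices)` branch above.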