Unverified Commit 762cba3b authored by Yu Liu's avatar Yu Liu Committed by GitHub
Browse files

Albert pretrain datasets/ datacollator (#6168)



* add dataset for albert pretrain

* datacollator for albert pretrain

* naming, comprehension, file reading change

* data cleaning is no needed after this modification

* delete prints

* fix a bug

* file structure change

* add tests for albert datacollator

* remove random seed

* add back len and get item function

* sample file for testing and test code added

* format change for black

* more format change

* Style

* var assignment issue resolve

* add back wrongly deleted DataCollatorWithPadding in init file

* Style
Co-authored-by: default avatarLysandre Debut <lysandre@huggingface.co>
Co-authored-by: default avatarLysandre <lysandre.debut@reseau.eseo.fr>
parent 49e9be06
......@@ -207,6 +207,7 @@ if is_torch_available():
DataCollatorForLanguageModeling,
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
DataCollatorWithPadding,
default_data_collator,
)
......@@ -214,6 +215,7 @@ if is_torch_available():
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
LineByLineWithSOPTextDataset,
SquadDataset,
SquadDataTrainingArguments,
TextDataset,
......
......@@ -198,6 +198,75 @@ class DataCollatorForLanguageModeling:
return inputs, labels
@dataclass
class DataCollatorForSOP(DataCollatorForLanguageModeling):
"""
Data collator used for sentence order prediction task.
- collates batches of tensors, honoring their tokenizer's pad_token
- preprocesses batches for both masked language modeling and sentence order prediction
"""
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
input_ids = [example["input_ids"] for example in examples]
input_ids = self._tensorize_batch(input_ids)
input_ids, labels, attention_mask = self.mask_tokens(input_ids)
token_type_ids = [example["token_type_ids"] for example in examples]
# size of segment_ids varied because randomness, padding zero to the end as the orignal implementation
token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
sop_label_list = [example["sentence_order_label"] for example in examples]
sentence_order_label = torch.stack(sop_label_list)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"sentence_order_label": sentence_order_label,
}
def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10% original.
N-gram not applied yet.
"""
if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer._pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
# probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
attention_mask = (~masked_indices).float()
if self.tokenizer._pad_token is not None:
attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
attention_mask.masked_fill_(attention_padding_mask, value=1.0)
labels[~masked_indices] = -100 # We only compute loss on masked tokens, -100 is default for CE compute
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels, attention_mask
@dataclass
class DataCollatorForPermutationLanguageModeling:
"""
......
......@@ -3,5 +3,10 @@
# module, but to preserve other warnings. So, don't check this module at all.
from .glue import GlueDataset, GlueDataTrainingArguments
from .language_modeling import LineByLineTextDataset, TextDataset, TextDatasetForNextSentencePrediction
from .language_modeling import (
LineByLineTextDataset,
LineByLineWithSOPTextDataset,
TextDataset,
TextDatasetForNextSentencePrediction,
)
from .squad import SquadDataset, SquadDataTrainingArguments
import os
import pickle
import random
import time
from typing import Optional
from typing import Dict, Optional
import torch
from torch.utils.data.dataset import Dataset
......@@ -113,6 +114,147 @@ class LineByLineTextDataset(Dataset):
return torch.tensor(self.examples[i], dtype=torch.long)
class LineByLineWithSOPTextDataset(Dataset):
"""
Dataset for sentence order prediction task, prepare sentence pairs for SOP task
"""
def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
assert os.path.isdir(file_dir)
logger.info(f"Creating features from dataset file folder at {file_dir}")
self.examples = []
# TODO: randomness could apply a random seed, ex. rng = random.Random(random_seed)
# file path looks like ./dataset/wiki_1, ./dataset/wiki_2
for file_name in os.listdir(file_dir):
file_path = os.path.join(file_dir, file_name)
assert os.path.isfile(file_path)
article_open = False
with open(file_path, encoding="utf-8") as f:
original_lines = f.readlines()
article_lines = []
for line in original_lines:
if "<doc id=" in line:
article_open = True
elif "</doc>" in line:
article_open = False
document = [
tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
for line in article_lines[1:]
if (len(line) > 0 and not line.isspace())
]
examples = self.create_examples_from_document(document, block_size, tokenizer)
self.examples.extend(examples)
article_lines = []
else:
if article_open:
article_lines.append(line)
logger.info("Dataset parse finished.")
def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
"""Creates examples for a single document."""
# Account for special tokens
max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)
# We *usually* want to fill up the entire sequence since we are padding
# to `block_size` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `block_size` is a hard limit.
target_seq_length = max_num_tokens
if random.random() < short_seq_prob:
target_seq_length = random.randint(2, max_num_tokens)
# We DON'T just concatenate all of the tokens from a document into a long
# sequence and choose an arbitrary split point because this would make the
# next sentence prediction task too easy. Instead, we split the input into
# segments "A" and "B" based on the actual "sentences" provided by the user
# input.
examples = []
current_chunk = [] # a buffer stored current working segments
current_length = 0
i = 0
while i < len(document):
segment = document[i] # get a segment
if not segment:
i += 1
continue
current_chunk.append(segment) # add a segment to current chunk
current_length += len(segment) # overall token length
# if current length goes to the target length or reaches the end of file, start building token a and b
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
a_end = 1
# if current chunk has more than 2 sentences, pick part of it `A` (first) sentence
if len(current_chunk) >= 2:
a_end = random.randint(1, len(current_chunk) - 1)
# token a
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
# token b
tokens_b = []
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
if len(tokens_a) == 0 or len(tokens_b) == 0:
continue
# switch tokens_a and tokens_b randomly
if random.random() < 0.5:
is_next = False
tokens_a, tokens_b = tokens_b, tokens_a
else:
is_next = True
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
"""Truncates a pair of sequences to a maximum sequence length."""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# We want to sometimes truncate from the front and sometimes from the
# back to add more randomness and avoid biases.
if random.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
# add special tokens
input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
# add token type ids, 0 for sentence a, 1 for sentence b
token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)
example = {
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
"sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
}
examples.append(example)
current_chunk = [] # clear current chunk
current_length = 0 # reset current text length
i += 1 # go to next line
return examples
def __len__(self):
return len(self.examples)
def __getitem__(self, i) -> Dict[str, torch.tensor]:
return self.examples[i]
class TextDatasetForNextSentencePrediction(Dataset):
"""
This will be superseded by a framework-agnostic approach
......
This diff is collapsed.
......@@ -11,9 +11,11 @@ if is_torch_available():
DataCollatorForLanguageModeling,
DataCollatorForNextSentencePrediction,
DataCollatorForPermutationLanguageModeling,
DataCollatorForSOP,
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
LineByLineWithSOPTextDataset,
TextDataset,
TextDatasetForNextSentencePrediction,
default_data_collator,
......@@ -21,6 +23,7 @@ if is_torch_available():
PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
PATH_SAMPLE_TEXT_DIR = "./tests/fixtures/tests_samples/wiki_text"
@require_torch
......@@ -168,3 +171,19 @@ class DataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))
def test_sop(self):
tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
data_collator = DataCollatorForSOP(tokenizer)
dataset = LineByLineWithSOPTextDataset(tokenizer, file_dir=PATH_SAMPLE_TEXT_DIR, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
# Since there are randomly generated false samples, the total number of samples is not fixed.
total_samples = batch["input_ids"].shape[0]
self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
self.assertEqual(batch["sentence_order_label"].shape, torch.Size((total_samples,)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment