"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "bfe870be654a1fc54c5479f9ad0875492d9cd959"
Unverified commit 762cba3b authored by Yu Liu, committed by GitHub

Albert pretrain datasets / data collator (#6168)



* add dataset for albert pretrain

* datacollator for albert pretrain

* naming, comprehension, file reading change

* data cleaning is not needed after this modification

* delete prints

* fix a bug

* file structure change

* add tests for albert datacollator

* remove random seed

* add back __len__ and __getitem__ functions

* sample file for testing and test code added

* format change for black

* more format change

* Style

* resolve variable assignment issue

* add back wrongly deleted DataCollatorWithPadding in init file

* Style
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent 49e9be06
@@ -207,6 +207,7 @@ if is_torch_available():
        DataCollatorForLanguageModeling,
        DataCollatorForNextSentencePrediction,
        DataCollatorForPermutationLanguageModeling,
        DataCollatorForSOP,
        DataCollatorWithPadding,
        default_data_collator,
    )
@@ -214,6 +215,7 @@ if is_torch_available():
        GlueDataset,
        GlueDataTrainingArguments,
        LineByLineTextDataset,
        LineByLineWithSOPTextDataset,
        SquadDataset,
        SquadDataTrainingArguments,
        TextDataset,
......
@@ -198,6 +198,75 @@ class DataCollatorForLanguageModeling:
        return inputs, labels

@dataclass
class DataCollatorForSOP(DataCollatorForLanguageModeling):
    """
    Data collator used for the sentence order prediction task.
    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for both masked language modeling and sentence order prediction
    """

    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        input_ids = [example["input_ids"] for example in examples]
        input_ids = self._tensorize_batch(input_ids)
        input_ids, labels, attention_mask = self.mask_tokens(input_ids)

        token_type_ids = [example["token_type_ids"] for example in examples]
        # the size of token_type_ids varies because of randomness; pad at the end, as in the original implementation
        token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)

        sop_label_list = [example["sentence_order_label"] for example in examples]
        sentence_order_label = torch.stack(sop_label_list)

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "sentence_order_label": sentence_order_label,
        }

    def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10% original.
        N-gram masking is not applied yet.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability self.mlm_probability, which defaults to 0.15 as in BERT/RoBERTa)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # here the probability is 1 at masked positions, but in the ALBERT model an attention mask of 0 means masked, so invert the values
        attention_mask = (~masked_indices).float()
        if self.tokenizer._pad_token is not None:
            attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
            attention_mask.masked_fill_(attention_padding_mask, value=1.0)
        labels[~masked_indices] = -100  # we only compute loss on masked tokens; -100 is the default ignore_index for cross-entropy

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with a random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels, attention_mask

@dataclass
class DataCollatorForPermutationLanguageModeling:
    """
......
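For orientation, here is a minimal usage sketch of the new collator, assuming the transformers API as of this commit; the ./dataset path and the batch size are illustrative, not part of the change:

from torch.utils.data import DataLoader
from transformers import AlbertTokenizer, DataCollatorForSOP, LineByLineWithSOPTextDataset

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
# mlm_probability is inherited from DataCollatorForLanguageModeling (default 0.15);
# of the selected tokens, 80% become [MASK], 10% a random token, 10% stay unchanged
collator = DataCollatorForSOP(tokenizer=tokenizer)

dataset = LineByLineWithSOPTextDataset(tokenizer, file_dir="./dataset", block_size=512)
loader = DataLoader(dataset, batch_size=8, collate_fn=collator)

batch = next(iter(loader))
# keys: input_ids, labels, attention_mask, token_type_ids, sentence_order_label
print({name: tensor.shape for name, tensor in batch.items()})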
@@ -3,5 +3,10 @@
# module, but to preserve other warnings. So, don't check this module at all.
from .glue import GlueDataset, GlueDataTrainingArguments
from .language_modeling import (
    LineByLineTextDataset,
    LineByLineWithSOPTextDataset,
    TextDataset,
    TextDatasetForNextSentencePrediction,
)
from .squad import SquadDataset, SquadDataTrainingArguments
import os
import pickle
import random
import time
from typing import Dict, Optional

import torch
from torch.utils.data.dataset import Dataset
@@ -113,6 +114,147 @@ class LineByLineTextDataset(Dataset):
        return torch.tensor(self.examples[i], dtype=torch.long)

class LineByLineWithSOPTextDataset(Dataset):
    """
    Dataset for the sentence order prediction task: prepares sentence pairs for the SOP task.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, file_dir: str, block_size: int):
        assert os.path.isdir(file_dir)
        logger.info(f"Creating features from dataset file folder at {file_dir}")
        self.examples = []
        # TODO: for reproducibility, the randomness here could use a seed, e.g. rng = random.Random(random_seed)
        # file paths look like ./dataset/wiki_1, ./dataset/wiki_2
        for file_name in os.listdir(file_dir):
            file_path = os.path.join(file_dir, file_name)
            assert os.path.isfile(file_path)
            article_open = False
            with open(file_path, encoding="utf-8") as f:
                original_lines = f.readlines()
                article_lines = []
                for line in original_lines:
                    if "<doc id=" in line:
                        article_open = True
                    elif "</doc>" in line:
                        article_open = False
                        document = [
                            tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line))
                            for line in article_lines[1:]  # the first line inside each <doc> tag is the title; skip it
                            if (len(line) > 0 and not line.isspace())
                        ]
                        examples = self.create_examples_from_document(document, block_size, tokenizer)
                        self.examples.extend(examples)
                        article_lines = []
                    else:
                        if article_open:
                            article_lines.append(line)

        logger.info("Dataset parse finished.")

    def create_examples_from_document(self, document, block_size, tokenizer, short_seq_prob=0.1):
        """Creates examples for a single document."""

        # Account for special tokens
        max_num_tokens = block_size - tokenizer.num_special_tokens_to_add(pair=True)

        # We *usually* want to fill up the entire sequence since we are padding
        # to `block_size` anyway, so short sequences are generally wasted
        # computation. However, we *sometimes*
        # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
        # sequences to minimize the mismatch between pre-training and fine-tuning.
        # The `target_seq_length` is just a rough target, however, whereas
        # `block_size` is a hard limit.
        target_seq_length = max_num_tokens
        if random.random() < short_seq_prob:
            target_seq_length = random.randint(2, max_num_tokens)

        # We DON'T just concatenate all of the tokens from a document into a long
        # sequence and choose an arbitrary split point, because this would make the
        # sentence order prediction task too easy. Instead, we split the input into
        # segments "A" and "B" based on the actual "sentences" provided by the user
        # input.
        examples = []
        current_chunk = []  # a buffer storing the current working segments
        current_length = 0
        i = 0
        while i < len(document):
            segment = document[i]  # get a segment
            if not segment:
                i += 1
                continue
            current_chunk.append(segment)  # add the segment to the current chunk
            current_length += len(segment)  # overall token length
            # if the current length reaches the target length, or we are at the end of the document,
            # start building tokens a and b
            if i == len(document) - 1 or current_length >= target_seq_length:
                if current_chunk:
                    # `a_end` is how many segments from `current_chunk` go into the `A` (first) sentence.
                    a_end = 1
                    # if the current chunk has two or more segments, pick a random split point for the `A` (first) sentence
                    if len(current_chunk) >= 2:
                        a_end = random.randint(1, len(current_chunk) - 1)
                    # tokens a
                    tokens_a = []
                    for j in range(a_end):
                        tokens_a.extend(current_chunk[j])

                    # tokens b
                    tokens_b = []
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])

                    if len(tokens_a) == 0 or len(tokens_b) == 0:
                        continue

                    # switch tokens_a and tokens_b randomly
                    if random.random() < 0.5:
                        is_next = False
                        tokens_a, tokens_b = tokens_b, tokens_a
                    else:
                        is_next = True

                    def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
                        """Truncates a pair of sequences to a maximum sequence length."""
                        while True:
                            total_length = len(tokens_a) + len(tokens_b)
                            if total_length <= max_num_tokens:
                                break
                            trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
                            assert len(trunc_tokens) >= 1
                            # We want to sometimes truncate from the front and sometimes from the
                            # back to add more randomness and avoid biases.
                            if random.random() < 0.5:
                                del trunc_tokens[0]
                            else:
                                trunc_tokens.pop()

                    truncate_seq_pair(tokens_a, tokens_b, max_num_tokens)

                    assert len(tokens_a) >= 1
                    assert len(tokens_b) >= 1

                    # add special tokens
                    input_ids = tokenizer.build_inputs_with_special_tokens(tokens_a, tokens_b)
                    # add token type ids, 0 for sentence a, 1 for sentence b
                    token_type_ids = tokenizer.create_token_type_ids_from_sequences(tokens_a, tokens_b)

                    example = {
                        "input_ids": torch.tensor(input_ids, dtype=torch.long),
                        "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
                        "sentence_order_label": torch.tensor(0 if is_next else 1, dtype=torch.long),
                    }
                    examples.append(example)
                current_chunk = []  # clear the current chunk
                current_length = 0  # reset the current text length
            i += 1  # go to the next segment
        return examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return self.examples[i]

class TextDatasetForNextSentencePrediction(Dataset):
    """
    This will be superseded by a framework-agnostic approach
......
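The dataset expects a directory of WikiExtractor-style files: each article is wrapped in <doc id=...> ... </doc> tags, the first line inside a tag (the title) is skipped, and every remaining non-empty line is treated as one sentence segment. A small self-contained sketch of that input format, with made-up file contents and paths:

import os

from transformers import AlbertTokenizer, LineByLineWithSOPTextDataset

os.makedirs("./dataset", exist_ok=True)
with open("./dataset/wiki_1", "w", encoding="utf-8") as f:
    f.write('<doc id="1" url="https://example.com" title="Example">\n')
    f.write("Example\n")  # title line, skipped by the parser
    f.write("The first sentence of the article.\n")
    f.write("The second sentence of the article.\n")
    f.write("</doc>\n")

tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
dataset = LineByLineWithSOPTextDataset(tokenizer, file_dir="./dataset", block_size=128)

# each example holds input_ids, token_type_ids and a sentence_order_label:
# 0 if the pair is in the original order, 1 if the two segments were swapped
print(len(dataset), dataset[0]["sentence_order_label"])

With probability 0.5 the A and B segments are swapped, which is what produces the negative sentence-order examples.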
@@ -11,9 +11,11 @@ if is_torch_available():
        DataCollatorForLanguageModeling,
        DataCollatorForNextSentencePrediction,
        DataCollatorForPermutationLanguageModeling,
        DataCollatorForSOP,
        GlueDataset,
        GlueDataTrainingArguments,
        LineByLineTextDataset,
        LineByLineWithSOPTextDataset,
        TextDataset,
        TextDatasetForNextSentencePrediction,
        default_data_collator,
@@ -21,6 +23,7 @@ if is_torch_available():

PATH_SAMPLE_TEXT = "./tests/fixtures/sample_text.txt"
PATH_SAMPLE_TEXT_DIR = "./tests/fixtures/tests_samples/wiki_text"


@require_torch
@@ -168,3 +171,19 @@ class DataCollatorIntegrationTest(unittest.TestCase):
        self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
        self.assertEqual(batch["masked_lm_labels"].shape, torch.Size((total_samples, 512)))
        self.assertEqual(batch["next_sentence_label"].shape, torch.Size((total_samples,)))

    def test_sop(self):
        tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        data_collator = DataCollatorForSOP(tokenizer)
        dataset = LineByLineWithSOPTextDataset(tokenizer, file_dir=PATH_SAMPLE_TEXT_DIR, block_size=512)
        examples = [dataset[i] for i in range(len(dataset))]
        batch = data_collator(examples)
        self.assertIsInstance(batch, dict)

        # Since false samples are generated randomly, the total number of samples is not fixed.
        total_samples = batch["input_ids"].shape[0]
        self.assertEqual(batch["input_ids"].shape, torch.Size((total_samples, 512)))
        self.assertEqual(batch["token_type_ids"].shape, torch.Size((total_samples, 512)))
        self.assertEqual(batch["labels"].shape, torch.Size((total_samples, 512)))
        self.assertEqual(batch["sentence_order_label"].shape, torch.Size((total_samples,)))