"...grounding_dino_mmcv.git" did not exist on "b12850fe31d825b481832cd3c7123f3e50cd4f4f"
Unverified Commit 6303b5a7 authored by Huang Lianzhe, committed by GitHub

[Bug Fix] The actual batch_size is inconsistent with the settings. (#7235)



* [bug fix] fixed the bug where the actual batch_size was inconsistent with the parameter settings

* reformat

* reformat

* reformat

* add support for dict and BatchEncoding

* add support for dict and BatchEncoding

* add documentation for DataCollatorForNextSentencePrediction

* Some more nits for the docstring
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Some more nits for the docstring
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Some more nits for the docstring
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Some more nits for the docstring
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Some more nits for the docstring
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* rename variables
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent c754c41c
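Context for the fix (a sketch, not part of the diff): previously each dataset item was an entire tokenized document, and the collator expanded every document into several sentence pairs inside ``__call__``, so a ``DataLoader`` configured with ``batch_size=N`` could emit far more than ``N`` rows per batch. After this change the dataset pre-generates one dict per example and the collator maps examples to feature rows 1:1. A minimal check, with a hypothetical helper name:

from torch.utils.data import DataLoader

def check_batch_size(dataset, collator, batch_size=8):
    """Hypothetical helper (not in this commit): verify the collated batch size."""
    loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collator)
    batch = next(iter(loader))
    # Each pre-generated example now yields exactly one feature row, so the
    # leading dimension should equal the configured batch_size; with the old
    # collator a single "example" (a whole document) expanded into many rows.
    assert batch["input_ids"].size(0) == batch_size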
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
@@ -402,8 +401,8 @@ class DataCollatorForPermutationLanguageModeling:
@dataclass
class DataCollatorForNextSentencePrediction:
"""
Data collator used for language modeling.
- collates batches of tensors, honoring their tokenizer's pad_token
Data collator used for next sentence prediction.
- collates examples which contain pre-generated negative examples
- preprocesses batches for masked language modeling
"""
@@ -414,21 +413,30 @@ class DataCollatorForNextSentencePrediction:
nsp_probability: float = 0.5
mlm_probability: float = 0.15
def __call__(self, examples: List[Union[List[List[int]], Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
if isinstance(examples[0], (dict, BatchEncoding)):
examples = [e["input_ids"] for e in examples]
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
"""
The input should contain negative examples; :class:`~transformers.DataCollatorForNextSentencePrediction` will not generate any negative examples on its own.
Args:
examples (:obj:`List[Dict]`): Each dictionary should have the following keys:
- ``tokens_a``: A sequence of tokens, which should appear before ``tokens_b`` in the text.
- ``tokens_b``: A sequence of tokens, which should appear after ``tokens_a`` in the text.
- ``is_random_next``: 1 if this pair is generated randomly, else 0.
"""
tokens_a = [e["tokens_a"] for e in examples]
tokens_b = [e["tokens_b"] for e in examples]
nsp_labels = [1 if e["is_random_next"] else 0 for e in examples]
input_ids = []
segment_ids = []
attention_masks = []
nsp_labels = []
for i, doc in enumerate(examples):
input_id, segment_id, attention_mask, label = self.create_examples_from_document(doc, i, examples)
input_ids.extend(input_id)
segment_ids.extend(segment_id)
attention_masks.extend(attention_mask)
nsp_labels.extend(label)
assert len(tokens_a) == len(tokens_b)
for i in range(len(tokens_a)):
input_id, attention_mask, segment_id = self.create_features_from_example(tokens_a[i], tokens_b[i])
input_ids.append(input_id)
segment_ids.append(segment_id)
attention_masks.append(attention_mask)
if self.mlm:
input_ids, mlm_labels = self.mask_tokens(self._tensorize_batch(input_ids))
else:
@@ -438,6 +446,7 @@ class DataCollatorForNextSentencePrediction:
"input_ids": input_ids,
"attention_mask": self._tensorize_batch(attention_masks),
"token_type_ids": self._tensorize_batch(segment_ids),
"masked_lm_labels": mlm_labels if self.mlm else None,
"next_sentence_label": torch.tensor(nsp_labels),
}
if self.mlm:
@@ -457,82 +466,11 @@ class DataCollatorForNextSentencePrediction:
)
return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
def create_examples_from_document(
self, document: List[List[int]], doc_index: int, examples: List[List[List[int]]]
):
def create_features_from_example(self, tokens_a, tokens_b):
"""Creates examples for a single document."""
max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
# We *usually* want to fill up the entire sequence since we are padding
# to `block_size` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `block_size` is a hard limit.
target_seq_length = max_num_tokens
if random.random() < self.short_seq_probability:
target_seq_length = random.randint(2, max_num_tokens)
current_chunk = [] # a buffer storing the current working segments
current_length = 0
i = 0
input_ids = []
segment_ids = []
attention_masks = []
labels = []
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = random.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
if len(current_chunk) == 1 or random.random() < self.nsp_probability:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# This should rarely go for more than one iteration for large
# corpora. However, just to be careful, we try to make sure that
# the random document is not the same as the document
# we're processing. Also check to make sure that the random document
# is not empty.
for _ in range(10):
random_document_index = random.randint(0, len(examples) - 1)
if random_document_index != doc_index and len(examples[random_document_index]) > 0:
break
random_document = examples[random_document_index]
random_start = random.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens_a, tokens_b, _ = self.tokenizer.truncate_sequences(
tokens_a,
tokens_b,
@@ -551,17 +489,11 @@ class DataCollatorForNextSentencePrediction:
attention_mask.append(0)
segment_id.append(0)
input_ids.append(torch.tensor(input_id))
segment_ids.append(torch.tensor(segment_id))
attention_masks.append(torch.tensor(attention_mask))
labels.append(torch.tensor(1 if is_random_next else 0))
current_chunk = []
current_length = 0
i += 1
input_id = torch.tensor(input_id)
attention_mask = torch.tensor(attention_mask)
segment_id = torch.tensor(segment_id)
return input_ids, segment_ids, attention_masks, labels
return input_id, attention_mask, segment_id
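# Illustration (not part of the diff): with block_size = 8, tokens_a = [7, 8]
# and tokens_b = [9], bert-base-uncased's special tokens give
# [CLS] 7 8 [SEP] 9 [SEP] (six ids), and the padding loop above extends each
# list to block_size:
#   input_id       -> [101, 7, 8, 102, 9, 102, 0, 0]
#   attention_mask -> [1, 1, 1, 1, 1, 1, 0, 0]
#   segment_id     -> [0, 0, 0, 0, 1, 1, 0, 0]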
def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
...
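For reference between the two files (an illustration, not part of the diff): per the new ``__call__`` docstring above, each element handed to the collator is a dict pre-generated by ``TextDatasetForNextSentencePrediction``. The token ids below are invented:

example = {
    "tokens_a": [2023, 2003, 1037],  # ids of the sentence that appears first
    "tokens_b": [6251, 3940, 1012],  # ids of the sentence that appears second
    "is_random_next": False,         # True when tokens_b came from a random document
}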
@@ -2,7 +2,7 @@ import os
import pickle
import random
import time
from typing import Dict, Optional
from typing import Dict, List, Optional
import torch
from torch.utils.data.dataset import Dataset
@@ -267,10 +267,14 @@ class TextDatasetForNextSentencePrediction(Dataset):
file_path: str,
block_size: int,
overwrite_cache=False,
short_seq_probability=0.1,
nsp_probability=0.5,
):
assert os.path.isfile(file_path), f"Input file path {file_path} not found"
block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True)
self.block_size = block_size - tokenizer.num_special_tokens_to_add(pair=True)
self.short_seq_probability = short_seq_probability
self.nsp_probability = nsp_probability
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(
@@ -283,7 +287,6 @@ class TextDatasetForNextSentencePrediction(Dataset):
)
self.tokenizer = tokenizer
self.examples = []
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
@@ -313,7 +316,7 @@ class TextDatasetForNextSentencePrediction(Dataset):
else:
logger.info(f"Creating features from dataset file at {directory}")
self.examples = [[]]
self.documents = [[]]
with open(file_path, encoding="utf-8") as f:
while True:
line = f.readline()
@@ -322,12 +325,17 @@ class TextDatasetForNextSentencePrediction(Dataset):
line = line.strip()
# Empty lines are used as document delimiters
if not line and len(self.examples[-1]) != 0:
self.examples.append([])
if not line and len(self.documents[-1]) != 0:
self.documents.append([])
tokens = tokenizer.tokenize(line)
tokens = tokenizer.convert_tokens_to_ids(tokens)
if tokens:
self.examples[-1].append(tokens)
self.documents[-1].append(tokens)
logger.info(f"Creating examples from {len(self.documents)} documents.")
self.examples = []
for doc_index, document in enumerate(self.documents):
self.create_examples_from_document(document, doc_index)
start = time.time()
with open(cached_features_file, "wb") as handle:
@@ -336,6 +344,85 @@ class TextDatasetForNextSentencePrediction(Dataset):
"Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
)
def create_examples_from_document(self, document: List[List[int]], doc_index: int):
"""Creates examples for a single document."""
max_num_tokens = self.block_size - self.tokenizer.num_special_tokens_to_add(pair=True)
# We *usually* want to fill up the entire sequence since we are padding
# to `block_size` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `block_size` is a hard limit.
target_seq_length = max_num_tokens
if random.random() < self.short_seq_probability:
target_seq_length = random.randint(2, max_num_tokens)
current_chunk = [] # a buffer storing the current working segments
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = random.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
if len(current_chunk) == 1 or random.random() < self.nsp_probability:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# This should rarely go for more than one iteration for large
# corpora. However, just to be careful, we try to make sure that
# the random document is not the same as the document
# we're processing.
for _ in range(10):
random_document_index = random.randint(0, len(self.documents) - 1)
if random_document_index != doc_index:
break
random_document = self.documents[random_document_index]
random_start = random.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
self.examples.append(
{"tokens_a": tokens_a, "tokens_b": tokens_b, "is_random_next": is_random_next}
)
current_chunk = []
current_length = 0
i += 1
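# Illustration (not part of the diff): for a toy document of three tokenized
# sentences, document = [[1, 2, 3], [4, 5], [6, 7, 8]], a sufficiently large
# target_seq_length keeps the whole document in one chunk; with a_end == 1 and
# the "actual next" branch taken, the loop above appends:
# {"tokens_a": [1, 2, 3], "tokens_b": [4, 5, 6, 7, 8], "is_random_next": False}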
def __len__(self):
return len(self.examples)
...
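Putting both changed files together, a hedged end-to-end sketch (the corpus path is a placeholder; argument names follow this diff, and the collator is assumed to keep its ``block_size`` field at this commit):

from torch.utils.data import DataLoader
from transformers import (
    BertTokenizer,
    DataCollatorForNextSentencePrediction,
    TextDatasetForNextSentencePrediction,
)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
dataset = TextDatasetForNextSentencePrediction(
    tokenizer=tokenizer,
    file_path="corpus.txt",      # placeholder: one sentence per line, blank lines between documents
    block_size=128,
    short_seq_probability=0.1,   # constructor argument added by this commit
    nsp_probability=0.5,         # constructor argument added by this commit
)
collator = DataCollatorForNextSentencePrediction(tokenizer=tokenizer, block_size=128)
loader = DataLoader(dataset, batch_size=16, collate_fn=collator)

batch = next(iter(loader))
print(batch["input_ids"].shape)  # expected torch.Size([16, 128]): the actual batch size matches the setting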