Unverified Commit c7faf2cc authored by Kilian Kluge, committed by GitHub

[Documentation] Warn that DataCollatorForWholeWordMask is limited to BertTokenizer-like tokenizers (#12371)

* Notify users that DataCollatorForWholeWordMask is limited to BertTokenizer-like tokenizers

* Fix code formatting
parent ff5cdc08
@@ -22,6 +22,7 @@ from torch.nn.utils.rnn import pad_sequence
from ..file_utils import PaddingStrategy
from ..modeling_utils import PreTrainedModel
+from ..models.bert import BertTokenizer, BertTokenizerFast
from ..tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
@@ -395,10 +396,17 @@ class DataCollatorForLanguageModeling:
@dataclass
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
    """
-    Data collator used for language modeling.
+    Data collator used for language modeling that masks entire words.

    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling
+
+    .. note::
+
+        This collator relies on details of the implementation of subword tokenization by
+        :class:`~transformers.BertTokenizer`, specifically that subword tokens are prefixed with `##`. For tokenizers
+        that do not adhere to this scheme, this collator will produce an output that is roughly equivalent to
+        :class:`.DataCollatorForLanguageModeling`.
    """

    def __call__(
@@ -435,6 +443,11 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
        """
        Get 0/1 labels for masked tokens with whole word mask proxy
        """
+        if not isinstance(self.tokenizer, (BertTokenizer, BertTokenizerFast)):
+            warnings.warn(
+                "DataCollatorForWholeWordMask is only suitable for BertTokenizer-like tokenizers. "
+                "Please refer to the documentation for more information."
+            )
        cand_indexes = []
        for (i, token) in enumerate(input_tokens):
...
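
For context (not part of this commit), here is a minimal sketch of why the collator depends on BERT-style tokenization: WordPiece tokenizers such as BertTokenizerFast mark word-internal pieces with the `##` prefix the collator looks for, while a byte-level BPE tokenizer such as GPT-2's does not. The checkpoint names and the sample sentence below are illustrative only.

```python
# Minimal sketch (not part of this diff): shows the `##` subword prefix that
# DataCollatorForWholeWordMask relies on, and basic use of the collator.
# Checkpoint names and the sample sentence are placeholders.
from transformers import BertTokenizerFast, DataCollatorForWholeWordMask, GPT2TokenizerFast

bert_tok = BertTokenizerFast.from_pretrained("bert-base-uncased")
gpt2_tok = GPT2TokenizerFast.from_pretrained("gpt2")

# BERT's WordPiece marks word-internal pieces with "##" (e.g. ['token', '##ization']);
# GPT-2's byte-level BPE uses no such prefix, so whole words cannot be reassembled this way.
print(bert_tok.tokenize("tokenization"))
print(gpt2_tok.tokenize("tokenization"))

# With a BERT-like tokenizer, the collator masks all subword pieces of a word together.
collator = DataCollatorForWholeWordMask(tokenizer=bert_tok, mlm_probability=0.15)
examples = [bert_tok("Whole word masking keeps subword pieces together.")]
batch = collator(examples)
print(batch["input_ids"].shape, batch["labels"].shape)
```

With a tokenizer that lacks the `##` convention, every piece is treated as a word of its own, so the masking degenerates to roughly the per-token behaviour of DataCollatorForLanguageModeling, which is what the new warning points out.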