Unverified Commit a16e568f authored by wlhgtc, committed by GitHub

# Add whole word mask support for lm fine-tune (#7925)



* ADD: add whole word mask proxy for both eng and chinese

* MOD: adjust format

* MOD: reformat code

* MOD: update import

* MOD: fix bug

* MOD: add import

* MOD: fix bug

* MOD: decouple code and update readme

* MOD: reformat code

* Update examples/language-modeling/README.md
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/README.md
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update examples/language-modeling/run_language_modeling.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* change wwm to whole_word_mask

* reformat code

* reformat

* format

* Code quality

* ADD: update chinese ref readme

* MOD: small changes

* MOD: small changes2

* update readme
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
parent 64b4d25c
...@@ -45,9 +45,56 @@ slightly slower (over-fitting takes more epochs).
We use the `--mlm` flag so that the script may change its loss function.
If using whole-word masking, use both the `--mlm` and `--whole_word_mask` flags.
```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
python run_language_modeling.py \
--output_dir=output \
--model_type=roberta \
--model_name_or_path=roberta-base \
--do_train \
--train_data_file=$TRAIN_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--mlm \
--whole_word_mask
```
For Chinese models, the procedure is the same as for English models with only `--mlm`. If using whole-word masking, we also need to generate a reference file, because the Chinese tokenizer operates at the character level.

**Q:** Why a ref file?

**A:** Suppose we have a Chinese sentence like `我喜欢你`. The original Chinese BERT tokenizes it at the character level as `['我','喜','欢','你']`, but `喜欢` is actually a whole word. For the whole-word-mask proxy we need a result like `['我','喜','##欢','你']`, so the ref file tells the model which positions in the original BERT tokens should be prefixed with `##`.
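Concretely, each line of the ref file is a JSON list of the token positions that need the `##` prefix. A minimal illustration (assuming positions are indices into the BERT token sequence, with `[CLS]` at index 0):

```python
import json

# For 我喜欢你, BERT tokenizes (with special tokens) to
#   ['[CLS]', '我', '喜', '欢', '你', '[SEP]']
# while LTP segments the sentence into words: ['我', '喜欢', '你'].
# 欢 continues the word 喜欢, so its position (3) is recorded.
ref_line = json.dumps([3])
print(ref_line)  # -> [3]  (one JSON list per line of the training file)
```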
**Q:** Why LTP?

**A:** Because the best-known Chinese whole-word-mask BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT, which works well on many Chinese tasks such as CLUE (the Chinese GLUE). It uses LTP for word segmentation, so if we want to fine-tune their model we need LTP as well.
```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export LTP_RESOURCE=/path/to/ltp/tokenizer
export BERT_RESOURCE=/path/to/bert/tokenizer
export SAVE_PATH=/path/to/data/ref.txt
python chinese_ref.py \
--file_name=$TRAIN_FILE \
--ltp=$LTP_RESOURCE \
--bert=$BERT_RESOURCE \
--save_path=$SAVE_PATH
```
The Chinese ref file is currently only supported by the `LineByLineWithRefDataset` class, so we also need to add the `--line_by_line` flag:
```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export TEST_FILE=/path/to/dataset/wiki.test.raw
export REF_FILE=/path/to/ref.txt
python run_language_modeling.py \
--output_dir=output \
--model_type=roberta \
--model_name_or_path=roberta-base \
--do_train \
--train_data_file=$TRAIN_FILE \
--chinese_ref_file=$REF_FILE \
--do_eval \
--eval_data_file=$TEST_FILE \
--mlm \
--line_by_line \
--whole_word_mask
```
### XLNet and permutation language modeling
......
import argparse
import json
from typing import List
from ltp import LTP
from transformers.tokenization_bert import BertTokenizer
def _is_chinese_char(cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and are handled
    # like all of the other languages.
    if (
        (cp >= 0x4E00 and cp <= 0x9FFF)
        or (cp >= 0x3400 and cp <= 0x4DBF)
        or (cp >= 0x20000 and cp <= 0x2A6DF)
        or (cp >= 0x2A700 and cp <= 0x2B73F)
        or (cp >= 0x2B740 and cp <= 0x2B81F)
        or (cp >= 0x2B820 and cp <= 0x2CEAF)
        or (cp >= 0xF900 and cp <= 0xFAFF)
        or (cp >= 0x2F800 and cp <= 0x2FA1F)
    ):
        return True
    return False


def is_chinese(word: str):
    # word like '180' or '身高' or '神' -- counts as Chinese only if every char is a CJK char
    for char in word:
        char = ord(char)
        if not _is_chinese_char(char):
            return 0
    return 1
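# e.g. is_chinese("身高") -> 1, is_chinese("神") -> 1, is_chinese("180") -> 0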
def get_chinese_word(tokens: List[str]):
    word_set = set()
    for token in tokens:
        chinese_word = len(token) > 1 and is_chinese(token)
        if chinese_word:
            word_set.add(token)
    word_list = list(word_set)
    return word_list
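# e.g. get_chinese_word(["我", "喜欢", "你", "180"]) -> ["喜欢"]
# (only multi-character, fully-Chinese tokens are kept)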
def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set):
    if not chinese_word_set:
        return bert_tokens
    max_word_len = max([len(w) for w in chinese_word_set])

    bert_word = bert_tokens
    start, end = 0, len(bert_word)
    while start < end:
        single_word = True
        if is_chinese(bert_word[start]):
            # Try to match the longest segmented Chinese word starting at this position
            max_match_len = min(end - start, max_word_len)
            for i in range(max_match_len, 1, -1):
                whole_word = "".join(bert_word[start : start + i])
                if whole_word in chinese_word_set:
                    # Prefix every continuation character of the matched word with ##
                    for j in range(start + 1, start + i):
                        bert_word[j] = "##" + bert_word[j]
                    start = start + i
                    single_word = False
                    break
        if single_word:
            start += 1
    return bert_word
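# e.g. add_sub_symbol(["我", "喜", "欢", "你"], {"喜欢"}) -> ["我", "喜", "##欢", "你"],
# since 欢 continues the LTP-segmented word 喜欢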
def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
    # Segment the lines with LTP (in batches of 100) and keep only the Chinese words
    ltp_res = []
    for i in range(0, len(lines), 100):
        res = ltp_tokenizer.seg(lines[i : i + 100])[0]
        res = [get_chinese_word(r) for r in res]
        ltp_res.extend(res)
    assert len(ltp_res) == len(lines)

    # Tokenize the same lines with the BERT tokenizer
    bert_res = []
    for i in range(0, len(lines), 100):
        res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512)
        bert_res.extend(res["input_ids"])
    assert len(bert_res) == len(lines)

    ref_ids = []
    for input_ids, chinese_word in zip(bert_res, ltp_res):
        input_tokens = []
        for id in input_ids:
            token = bert_tokenizer._convert_id_to_token(id)
            input_tokens.append(token)
        input_tokens = add_sub_symbol(input_tokens, chinese_word)
        ref_id = []
        # We only save the positions of Chinese subwords that start with ##,
        # i.e. those that are part of a whole word.
        for i, token in enumerate(input_tokens):
            if token[:2] == "##":
                clean_token = token[2:]
                # save the Chinese tokens' positions
                if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
                    ref_id.append(i)
        ref_ids.append(ref_id)
    assert len(ref_ids) == len(bert_res)
    return ref_ids
def main(args):
    # For Chinese (Ro)BERT(a), the best results come from RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm).
    # To fine-tune those models, we have to use the same word segmenter they used: LTP (https://github.com/HIT-SCIR/ltp)
    with open(args.file_name, "r", encoding="utf-8") as f:
        data = f.readlines()

    ltp_tokenizer = LTP(args.ltp)  # LTP is faster on a GPU device
    bert_tokenizer = BertTokenizer.from_pretrained(args.bert)
    ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)

    with open(args.save_path, "w", encoding="utf-8") as f:
        data = [json.dumps(ref) + "\n" for ref in ref_ids]
        f.writelines(data)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="prepare_chinese_ref")
    parser.add_argument(
        "--file_name",
        type=str,
        default="./resources/chinese-demo.txt",
        help="file to process, same as the training data for the LM",
    )
    parser.add_argument(
        "--ltp", type=str, default="./resources/ltp", help="resources for the LTP tokenizer, usually a path"
    )
    parser.add_argument("--bert", type=str, default="./resources/robert", help="resources for the BERT tokenizer")
    parser.add_argument("--save_path", type=str, default="./resources/ref.txt", help="path to save the result")
    args = parser.parse_args()
    main(args)
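As a quick sanity check of `prepare_ref`, a minimal sketch (assumptions: the LTP resources have been downloaded locally, the public `bert-base-chinese` checkpoint is used, and the script above is importable under the hypothetical module name `chinese_ref`):

```python
from ltp import LTP
from transformers.tokenization_bert import BertTokenizer

from chinese_ref import prepare_ref  # hypothetical module name for the script above

ltp_tokenizer = LTP("./resources/ltp")  # path to the downloaded LTP model
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

print(prepare_ref(["我喜欢你"], ltp_tokenizer, bert_tokenizer))
# expected: [[3]] -- index 3 is 欢, the continuation of the word 喜欢
```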
...@@ -37,8 +37,10 @@ from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorForPermutationLanguageModeling,
    DataCollatorForWholeWordMask,
    HfArgumentParser,
    LineByLineTextDataset,
    LineByLineWithRefDataset,
    PreTrainedTokenizer,
    TextDataset,
    Trainer,
...@@ -101,6 +103,10 @@ class DataTrainingArguments:
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    chinese_ref_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input ref data file for whole word masking in Chinese."},
    )
    line_by_line: bool = field(
        default=False,
        metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
...@@ -109,6 +115,7 @@ class DataTrainingArguments:
    mlm: bool = field(
        default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."}
    )
    whole_word_mask: bool = field(default=False, metadata={"help": "Whether or not to use whole word masking."})
    mlm_probability: float = field(
        default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
    )
...@@ -143,6 +150,16 @@ def get_dataset(
):
    def _dataset(file_path):
        if args.line_by_line:
            if args.chinese_ref_file is not None:
                if not args.whole_word_mask or not args.mlm:
                    raise ValueError("You need to set whole word masking and mlm to True for Chinese Whole Word Mask")
                return LineByLineWithRefDataset(
                    tokenizer=tokenizer,
                    file_path=file_path,
                    block_size=args.block_size,
                    ref_path=args.chinese_ref_file,
                )
            return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
        else:
            return TextDataset(
...@@ -174,7 +191,6 @@ def main():
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
...@@ -270,9 +286,14 @@ def main():
            max_span_length=data_args.max_span_length,
        )
    else:
        if data_args.mlm and data_args.whole_word_mask:
            data_collator = DataCollatorForWholeWordMask(
                tokenizer=tokenizer, mlm_probability=data_args.mlm_probability
            )
        else:
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
            )
    # Initialize our Trainer
    trainer = Trainer(
......
...@@ -284,6 +284,7 @@ if is_torch_available():
        DataCollatorForNextSentencePrediction,
        DataCollatorForPermutationLanguageModeling,
        DataCollatorForSOP,
        DataCollatorForWholeWordMask,
        DataCollatorWithPadding,
        default_data_collator,
    )
...@@ -291,6 +292,7 @@ if is_torch_available():
        GlueDataset,
        GlueDataTrainingArguments,
        LineByLineTextDataset,
        LineByLineWithRefDataset,
        LineByLineWithSOPTextDataset,
        SquadDataset,
        SquadDataTrainingArguments,
......
import random
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
...@@ -195,6 +196,124 @@ class DataCollatorForLanguageModeling:
        return inputs, labels
@dataclass
class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
    """
    Data collator used for whole-word masked language modeling.
    - collates batches of tensors, honoring their tokenizer's pad_token
    - preprocesses batches for masked language modeling, masking all subword tokens of a word together
    """

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        if isinstance(examples[0], (dict, BatchEncoding)):
            input_ids = [e["input_ids"] for e in examples]
        else:
            input_ids = examples
            examples = [{"input_ids": e} for e in examples]

        batch_input = self._tensorize_batch(input_ids)

        mask_labels = []
        for e in examples:
            ref_tokens = []
            for id in e["input_ids"].tolist():
                token = self.tokenizer._convert_id_to_token(id)
                ref_tokens.append(token)

            # For Chinese tokens, we need extra info to mark sub-words, e.g. [喜, 欢] -> [喜, ##欢]
            if "chinese_ref" in e:
                ref_pos = e["chinese_ref"].tolist()
                len_seq = e["input_ids"].size(0)
                for i in range(len_seq):
                    if i in ref_pos:
                        ref_tokens[i] = "##" + ref_tokens[i]
            mask_labels.append(self._whole_word_mask(ref_tokens))
        batch_mask = self._tensorize_batch(mask_labels)
        inputs, labels = self.mask_tokens(batch_input, batch_mask)
        return {"input_ids": inputs, "labels": labels}
    def _whole_word_mask(self, input_tokens: List[str], max_predictions=512):
        """
        Get 0/1 labels for masked tokens with the whole-word-mask proxy.
        """
        # Group subword tokens into candidate word spans: a token starting with ##
        # belongs to the same word as the token before it.
        cand_indexes = []
        for (i, token) in enumerate(input_tokens):
            if token == "[CLS]" or token == "[SEP]":
                continue

            if len(cand_indexes) >= 1 and token.startswith("##"):
                cand_indexes[-1].append(i)
            else:
                cand_indexes.append([i])

        random.shuffle(cand_indexes)
        num_to_predict = min(max_predictions, max(1, int(round(len(input_tokens) * self.mlm_probability))))
        masked_lms = []
        covered_indexes = set()
        for index_set in cand_indexes:
            if len(masked_lms) >= num_to_predict:
                break
            # If adding a whole-word mask would exceed the maximum number of
            # predictions, then just skip this candidate.
            if len(masked_lms) + len(index_set) > num_to_predict:
                continue
            is_any_index_covered = False
            for index in index_set:
                if index in covered_indexes:
                    is_any_index_covered = True
                    break
            if is_any_index_covered:
                continue
            for index in index_set:
                covered_indexes.add(index)
                masked_lms.append(index)

        assert len(covered_indexes) == len(masked_lms)
        mask_labels = [1 if i in covered_indexes else 0 for i in range(len(input_tokens))]
        return mask_labels
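    # Example: for input_tokens ['[CLS]', '我', '喜', '##欢', '你', '[SEP]'], the candidate
    # word spans are [[1], [2, 3], [4]]; if the span [2, 3] is sampled, mask_labels is
    # [0, 0, 1, 1, 0, 0] -- 喜 and ##欢 are always masked (or kept) together.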
    def mask_tokens(self, inputs: torch.Tensor, mask_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        Passing 'mask_labels' means we use whole-word masking (wwm): indices are masked directly according
        to the provided reference instead of being sampled per token.
        """
        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability
        # self.mlm_probability, which defaults to 0.15 as in BERT/RoBERTa)
        probability_matrix = mask_labels

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with a random word
        # (0.5 of the remaining 20% of masked positions)
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
@dataclass
class DataCollatorForSOP(DataCollatorForLanguageModeling):
    """
......
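Putting the new pieces together, a minimal sketch of how the collator and the ref-aware dataset might be wired into a `Trainer` outside the example script (the checkpoint name and file paths here are placeholders, not part of this change):

```python
from transformers import (
    AutoModelWithLMHead,
    BertTokenizer,
    DataCollatorForWholeWordMask,
    LineByLineWithRefDataset,
    Trainer,
    TrainingArguments,
)

tokenizer = BertTokenizer.from_pretrained("hfl/chinese-bert-wwm-ext")  # placeholder checkpoint
model = AutoModelWithLMHead.from_pretrained("hfl/chinese-bert-wwm-ext")

dataset = LineByLineWithRefDataset(
    tokenizer=tokenizer,
    file_path="/path/to/train.txt",  # one sentence per line
    block_size=512,
    ref_path="/path/to/ref.txt",  # produced by chinese_ref.py
)
collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.15)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="output"),
    train_dataset=dataset,
    data_collator=collator,
)
trainer.train()
```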
...@@ -5,6 +5,7 @@
from .glue import GlueDataset, GlueDataTrainingArguments
from .language_modeling import (
    LineByLineTextDataset,
    LineByLineWithRefDataset,
    LineByLineWithSOPTextDataset,
    TextDataset,
    TextDatasetForNextSentencePrediction,
......
import json
import os
import pickle
import random
...@@ -106,12 +107,48 @@ class LineByLineTextDataset(Dataset):
        batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
        self.examples = batch_encoding["input_ids"]
        self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]
    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> Dict[str, torch.tensor]:
        return self.examples[i]
class LineByLineWithRefDataset(Dataset):
    """
    This will be superseded by a framework-agnostic approach soon.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, ref_path: str):
        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
        assert os.path.isfile(ref_path), f"Ref file path {ref_path} not found"
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        logger.info("Creating features from dataset file at %s", file_path)
        logger.info("Using ref segment results at %s", ref_path)
        with open(file_path, encoding="utf-8") as f:
            data = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
        batch_encoding = tokenizer(data, add_special_tokens=True, truncation=True, max_length=block_size)
        self.examples = batch_encoding["input_ids"]
        self.examples = [{"input_ids": torch.tensor(e, dtype=torch.long)} for e in self.examples]

        # Get ref info from file
        with open(ref_path, encoding="utf-8") as f:
            ref = [json.loads(line) for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
        assert len(data) == len(ref)

        n = len(self.examples)
        for i in range(n):
            self.examples[i]["chinese_ref"] = torch.tensor(ref[i], dtype=torch.long)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> Dict[str, torch.tensor]:
        return self.examples[i]
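# Each example yielded by LineByLineWithRefDataset is a dict like
# {"input_ids": tensor([...]), "chinese_ref": tensor([3])},
# so DataCollatorForWholeWordMask can recover which positions continue a word.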
class LineByLineWithSOPTextDataset(Dataset):
......
...@@ -45,6 +45,11 @@ class DataCollatorForSOP:
        requires_pytorch(self)


class DataCollatorForWholeWordMask:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class DataCollatorWithPadding:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)
...@@ -69,6 +74,11 @@ class LineByLineTextDataset:
        requires_pytorch(self)


class LineByLineWithRefDataset:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)


class LineByLineWithSOPTextDataset:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)
......